diff --git a/.buildbot/Jenkinsfile b/.buildbot/Jenkinsfile index f9ce3a8bc7..e63fb6de3d 100644 --- a/.buildbot/Jenkinsfile +++ b/.buildbot/Jenkinsfile @@ -26,6 +26,7 @@ def call() { skipDefaultCheckout() buildDiscarder(logRotator(numToKeepStr:'1000')) ansiColor('xterm') + timeout(time: 6, unit: 'HOURS') } agent { diff --git a/docker/Dockerfile.onnx-mlir b/docker/Dockerfile.onnx-mlir index 1006f5052c..a9f8688e9a 100644 --- a/docker/Dockerfile.onnx-mlir +++ b/docker/Dockerfile.onnx-mlir @@ -43,6 +43,7 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ && CC=clang CXX=clang++ \ cmake -DMLIR_DIR=${LLVM_PROJECT_ROOT}/build/lib/cmake/mlir \ -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DCMAKE_INSTALL_MESSAGE=NEVER \ -DONNX_MLIR_ACCELERATORS=${ACCEL} .. \ && make -j${NPROC} \ diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md index 1f0c5cf4d0..eb2014d748 100644 --- a/docs/BuildOnLinuxOSX.md +++ b/docs/BuildOnLinuxOSX.md @@ -15,25 +15,38 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout 6461b921fd06b1c812f1172685b8b7edc0608af7 && cd .. +cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) ``` bash mkdir llvm-project/build cd llvm-project/build + cmake -G Ninja ../llvm \ - -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_RTTI=ON \ + -DENABLE_LIBOMPTARGET=OFF \ -DLLVM_ENABLE_LIBEDIT=OFF cmake --build . -- ${MAKEFLAGS} cmake --build . --target check-mlir ``` +To enable parallelization for onnx-mlir, llvm-project should be configured as +``` +cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_TARGETS_TO_BUILD="host" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_LIBEDIT=OFF +``` + ## ONNX-MLIR (this project) ### Build @@ -54,11 +67,15 @@ mkdir onnx-mlir/build && cd onnx-mlir/build if [[ -z "$pythonLocation" ]]; then cmake -G Ninja \ -DCMAKE_CXX_COMPILER=/usr/bin/c++ \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR=${MLIR_DIR} \ .. else cmake -G Ninja \ -DCMAKE_CXX_COMPILER=/usr/bin/c++ \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DPython3_ROOT_DIR=$pythonLocation \ -DMLIR_DIR=${MLIR_DIR} \ .. diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md index 2991d6f2b7..77650910c1 100644 --- a/docs/BuildOnWindows.md +++ b/docs/BuildOnWindows.md @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout 6461b921fd06b1c812f1172685b8b7edc0608af7 && cd .. +cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd .. 
``` [same-as-file]: <> (utils/build-mlir.cmd) @@ -62,13 +62,14 @@ md llvm-project\build cd llvm-project\build call cmake %root_dir%\llvm-project\llvm -G "Ninja" ^ -DCMAKE_INSTALL_PREFIX="%root_dir%\llvm-project\build\install" ^ - -DLLVM_ENABLE_PROJECTS=mlir ^ + -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" ^ -DLLVM_TARGETS_TO_BUILD="host" ^ -DCMAKE_BUILD_TYPE=Release ^ -DLLVM_ENABLE_ASSERTIONS=ON ^ -DLLVM_ENABLE_RTTI=ON ^ -DLLVM_ENABLE_ZLIB=OFF ^ -DLLVM_INSTALL_UTILS=ON ^ + -DENABLE_LIBOMPTARGET=OFF ^ -DLLVM_ENABLE_LIBEDIT=OFF call cmake --build . --config Release diff --git a/src/Accelerators/Accelerator.hpp b/src/Accelerators/Accelerator.hpp index 2044789127..5c2b47187e 100644 --- a/src/Accelerators/Accelerator.hpp +++ b/src/Accelerators/Accelerator.hpp @@ -4,7 +4,7 @@ //===-------------------------- Accelerator.hpp ---------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ACCELERATOR_H +#define ONNX_MLIR_ACCELERATOR_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/IR/BuiltinOps.h" @@ -32,14 +33,14 @@ #define CREATE_ACCEL_ENUM(name) name, #define DECLARE_ACCEL_INIT_FUNCTION(name) extern Accelerator *create##name(); #define INVOKE_ACCEL_INIT_FUNCTION(name, kinds) \ - if (!kinds.empty() && \ + if (!(kinds).empty() && \ llvm::is_contained(kinds, accel::Accelerator::Kind::name)) \ create##name()->setName(#name); #define CREATE_ACCEL_CL_ENUM(name) \ clEnumValN(accel::Accelerator::Kind::name, #name, #name " accelerator"), #define ACCEL_CL_ENUM_FROM_STRING(name, var, str) \ - if (str.compare(std::string(#name)) == 0) { \ - var = accel::Accelerator::Kind::name; \ + if ((str).compare(std::string(#name)) == 0) { \ + (var) = accel::Accelerator::Kind::name; \ return true; \ } #define ACCEL_CL_ENUM_TO_STRING(name, map) \ @@ -164,3 +165,4 @@ extern void initAccelerators(llvm::ArrayRef kinds); } // namespace accel } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index 7b70ba6872..ee4e0ae363 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -29,7 +29,8 @@ llvm::cl::opt nnpaEmissionTarget( llvm::cl::opt nnpaClipToDLFloatRange("nnpa-clip-to-dlfloat-range", llvm::cl::desc("Clip CPU tensors to dlfloat range before stickification to " "avoid out-of-range. Only clip Softmax inputs at this " - "moment. Default is true."), + "moment. Default is true. This option will be removed and " + "replaced by --nnpa-saturation in the future."), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt nnpaEnableZHighToOnnx("enable-zhigh-to-onnx", @@ -49,11 +50,13 @@ llvm::cl::opt nnpaEnableZHighDecomposeStickUnstick( "Default is false."), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +// Enabled default now, could also enable it only if parallel is on as parallel +// stick/unstick is quite a bit faster than sequential. llvm::cl::opt nnpaEnableCompilerStickUnstick( "enable-compiler-stick-unstick", llvm::cl::desc("[Experimental feature] Enable the compiler generate some " - "stick/unstick code. Default is false."), - llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); + "stick/unstick code. 
Default is true."), llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions)); llvm::cl::opt nnpaEnableScalarBcastBinary( "nnpa-enable-scalar-bcast-binary", @@ -93,6 +96,7 @@ llvm::cl::opt nnpaPlacementHeuristic{ llvm::cl::opt nnpaEnableSaturation("nnpa-saturation", llvm::cl::desc("Enable saturating f32 values before stickify them." + " This option turns enable-compiler-stick-unstick on. " "Default is false."), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index 8235709569..2b0343295c 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -4,14 +4,15 @@ //===------------------------ NNPACompilerOptions.hpp ---------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_NNPA_COMPILER_OPTIONS_H +#define ONNX_MLIR_NNPA_COMPILER_OPTIONS_H #include "llvm/Support/CommandLine.h" @@ -69,3 +70,4 @@ extern llvm::cl::opt nnpaSaveDevicePlacementFile; extern llvm::cl::opt nnpaEnableSaturation; } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 8f8bbc1f92..2d411da967 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -48,6 +48,15 @@ namespace onnx_mlir { void configurePassesNNPA() { configureOnnxToZHighLoweringPass(optReport == OptReport::NNPAUnsupportedOps); + // Compiler-generated sticks support saturation, so force their usage. + // TODO: remove this if zDNN adds support for saturation. + if (nnpaEnableSaturation) + nnpaEnableCompilerStickUnstick = true; + // Currently nnpaEnableCompilerStickUnstick is not supported on zOS. + // TODO: enable on zOS. + if (mtriple == "s390x-ibm-zos") { + nnpaEnableCompilerStickUnstick = false; + } } void addONNXToZHighPasses(mlir::PassManager &pm) { @@ -94,7 +103,8 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { // Clip zhigh.Stick inputs if required. This is to avoid out-of-range of // dlfloat. Do constant propagation after clipping to remove ONNX ops used for // clipping such as ONNXMax if applicable. - if (nnpaClipToDLFloatRange) { + // This pass will be removed and replaced by nnpa-saturation in the future. + if (!nnpaEnableSaturation && nnpaClipToDLFloatRange) { pm.addNestedPass( onnx_mlir::zhigh::createZHighClipToDLFloatPass()); pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); @@ -214,8 +224,8 @@ void addPassesNNPA(mlir::OwningOpRef &module, else if (optStr == "-O3") optLevel = OptLevel::O3; // Lower ONNX to Krnl, ZHigh to ZLow.
- addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true, - instrumentONNXSignature, ONNXOpStats); + addONNXToKrnlPasses( + pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats); if (nnpaEmissionTarget >= EmitZLowIR) emissionTarget = EmitMLIR; diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp index a4069dd9d6..3b7aec5652 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp @@ -4,14 +4,15 @@ //===------------------------- NNPACompilerUtils.hpp ----------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_NNPA_COMPILER_UTILS_H +#define ONNX_MLIR_NNPA_COMPILER_UTILS_H #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" @@ -37,3 +38,4 @@ void addPassesNNPA(mlir::OwningOpRef &module, void configurePassesNNPA(); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacementHeuristic.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacementHeuristic.hpp index d03647fcd7..f3ab214c91 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacementHeuristic.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacementHeuristic.hpp @@ -4,7 +4,7 @@ //===-------- DevicePlacementHeuristic.hpp - Place ops using model -------===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_HEURISTICS_H +#define ONNX_MLIR_HEURISTICS_H #include "mlir/IR/BuiltinOps.h" @@ -85,3 +86,4 @@ void PlaceBeneficialOpsOnNNPAWithStickUnstick(mlir::MLIRContext *context, double significantNNPAFactor = 3.0); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp index f4771b0007..f9c36372c4 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp @@ -4,7 +4,7 @@ //===---------- ONNXLegalityCheck.hpp - Check legality for ONNX ops -------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -14,7 +14,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_LEGALITY_H +#define ONNX_MLIR_LEGALITY_H #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" @@ -53,3 +54,5 @@ bool onnxToZHighUnsupportedReport( bool onnxToZHighInCompatibilityReport( mlir::Operation *op, std::string inputNNPALevel); + +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp index 334600a8aa..034f92a6e3 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp @@ -4,7 +4,7 @@ //====------ ONNXToZHigh.hpp - ONNX dialect to ZHigh lowering -------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_TO_ZHIGH_H +#define ONNX_MLIR_ONNX_TO_ZHIGH_H #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -30,3 +31,4 @@ void getONNXToZHighOneOpDynamicallyLegal( mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index 52efabf902..e8e68a0e37 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -2,7 +2,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===---------- ONNXToZHigh.hpp - Common functions in ONNXToZHigh ---------===// +//===---------- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh +//---------===// // // Copyright 2019-2024 The IBM Research Authors. // @@ -12,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZHIGH_COMMON_H +#define ONNX_MLIR_ZHIGH_COMMON_H #include "llvm/ADT/STLExtras.h" @@ -115,3 +117,4 @@ mlir::Value getDynShape( mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value x); } // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.hpp index 1b00220ee7..238796ad4e 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.hpp @@ -4,7 +4,7 @@ //===-------- PerfModel.hpp - Estimate if CPU or NNPA is faster ----------===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PERF_H +#define ONNX_MLIR_PERF_H #include "mlir/IR/BuiltinOps.h" @@ -32,3 +33,4 @@ double estimateTimeForStickOp(mlir::Value oper); double estimateTimeForUnstickOp(mlir::Value oper); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp index 4ab40101ac..3164c93b4e 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp @@ -4,14 +4,15 @@ //===--- RewriteONNXForZHigh.hpp - Rewrite ONNX ops for ZHigh lowering ----===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // // This file implements pass for rewriting of ONNX operations to generate // combination of ONNX and ZHigh operations. -#pragma once +#ifndef ONNX_MLIR_REWRITE_ZHIGH_H +#define ONNX_MLIR_REWRITE_ZHIGH_H #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,3 +29,4 @@ void getRewriteONNXForZHighDynamicallyLegal( mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index f6302f827c..0850227bef 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -140,17 +140,17 @@ static Value insertAllocForWorkAreaForRNNOps(IndexExprBuilderForKrnl &createIE, createIE.getShapeAsDims(rnnHiddenWeight, hiddenWeightDims); IndexExpr timestepExp = inputDims[0]; - IndexExpr Lit2 = LiteralIndexExpr(2); - IndexExpr NumOfGatesLit = LiteralIndexExpr(numOfGates); + IndexExpr Lit2 = LitIE(2); + IndexExpr NumOfGatesLit = LitIE(numOfGates); IndexExpr dim1 = hiddenWeightDims[1]; IndexExpr dim2 = inputDims[1]; - IndexExpr dim3 = LiteralIndexExpr(1); + IndexExpr dim3 = LitIE(1); IndexExpr dim4 = NumOfGatesLit * timestepExp + NumOfGatesLit + Lit2; - IndexExpr Lit1 = LiteralIndexExpr(1); - IndexExpr Lit32 = LiteralIndexExpr(32); - IndexExpr Lit64 = LiteralIndexExpr(64); - IndexExpr Lit4K = LiteralIndexExpr(4096); + IndexExpr Lit1 = LitIE(1); + IndexExpr Lit32 = LitIE(32); + IndexExpr Lit64 = LitIE(64); + IndexExpr Lit4K = LitIE(4096); IndexExpr ceilDim2 = (dim2 + Lit32 - Lit1).floorDiv(Lit32); IndexExpr ceilDim1 = (dim1 + Lit64 - Lit1).floorDiv(Lit64); IndexExpr sizeExpr = dim4 * dim3 * ceilDim2 * ceilDim1 * Lit4K; @@ -217,7 +217,7 @@ Value insertShapeMemRefI64( for (uint64_t i = 0; i < originalDims.size(); ++i) { Value dim = create.math.cast(rewriter.getI64Type(), originalDims[i].getValue()); - create.krnl.storeIE(dim, shapeMemRef, {LiteralIndexExpr(i)}); + create.krnl.storeIE(dim, shapeMemRef, {LitIE(i)}); } return shapeMemRef; } @@ -492,8 +492,9 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern { StringAttr layout = stickOp.getLayoutAttr(); IntegerAttr saturation = stickOp.getSaturationAttr(); - IndexExprBuilderForKrnl createKrnlIE(rewriter, loc); - ZHighStickOpShapeHelper shapeHelper(op, operands, &createKrnlIE); + MultiDialectBuilder create( + rewriter, loc); + 
ZHighStickOpShapeHelper shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); // Convert ZTensor type to MemRefType. @@ -503,9 +504,17 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern { // Allocate a buffer for the result MemRef. Value alloc = insertAllocForZMemRef( zMemRefType, shapeHelper.getOutputDims(), op, rewriter); - // Set pre-transformed layout: if NHWC, we can directly stickify from NCHW. - if (isNHWCLayout(layout)) - layout = getNCHWLayoutAttr(rewriter); + if (isNHWCLayout(layout)) { + if (nnpaEnableCompilerStickUnstick) { + // Compiler-generated stick does not support NCHW yet. + // Explicitly transpose NCHW to NHWC. + input = create.onnx.toMemref( + create.onnx.transposeInt64(input, ArrayRef({0, 2, 3, 1}))); + } else + // Otherwise, we can directly stickify from NCHW. + // Set pre-transformed layout to NCHW. + layout = getNCHWLayoutAttr(rewriter); + } // Else, emit a ZLow operation. rewriter.create(loc, input, alloc, layout, saturation); @@ -610,8 +619,9 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { StringAttr layout = getZTensorLayoutAttr(rewriter, op->getOperand(0).getType()); - IndexExprBuilderForKrnl createKrnlIE(rewriter, loc); - ZHighUnstickOpShapeHelper shapeHelper(op, operands, &createKrnlIE); + MultiDialectBuilder create( + rewriter, loc); + ZHighUnstickOpShapeHelper shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); // Convert ZTensor type to MemRefType. @@ -619,15 +629,40 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { convertZTensorToMemRefType(*op->result_type_begin()); // Allocate a buffer for the result MemRef. - Value alloc = insertAllocForZMemRef( - zMemRefType, shapeHelper.getOutputDims(), op, rewriter); - - // Set layout: if NHWC, we can directly unstickify to NCHW. - if (isNHWCLayout(layout)) - layout = getNCHWLayoutAttr(rewriter); + Value alloc = nullptr; + if (isNHWCLayout(layout)) { + if (nnpaEnableCompilerStickUnstick) { + // Compiler-generated unstick does not support NCHW yet. + // This code allocates an NHWC buffer. It gets dims from the NCHW input. + SmallVector dimList; + dimList.emplace_back(shapeHelper.getOutputDims()[0]); + dimList.emplace_back(shapeHelper.getOutputDims()[2]); + dimList.emplace_back(shapeHelper.getOutputDims()[3]); + dimList.emplace_back(shapeHelper.getOutputDims()[1]); + MultiDialectBuilder create(rewriter, loc); + MemRefType resType = zMemRefType.value; + ArrayRef shape = resType.getShape(); + alloc = create.mem.alignedAlloc( + MemRefType::get({shape[0], shape[2], shape[3], shape[1]}, + resType.getElementType()), + dimList); + } else { + // Otherwise, we can directly unstickify to NCHW. + // Set the layout to NCHW. + layout = getNCHWLayoutAttr(rewriter); + } + } + if (alloc == nullptr) + alloc = insertAllocForZMemRef( + zMemRefType, shapeHelper.getOutputDims(), op, rewriter); // Emit a ZLow operation. rewriter.create(loc, input, alloc, layout); + if (isNHWCLayout(layout) && nnpaEnableCompilerStickUnstick) + // Compiler-generated unstick does not support NCHW yet. + // Explicitly transpose NHWC to NCHW. + alloc = + create.onnx.transposeInt64(alloc, ArrayRef({0, 3, 1, 2})); rewriter.replaceOp(op, alloc); return success(); } @@ -1549,7 +1584,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering // inside the loop, and LLVM does not seem to read the f16 value.
uint64_t rank = mlir::cast(res.getType()).getRank(); ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs = shapeHelper.getOutputDims(); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, [&](KrnlBuilder &createKrnl, ValueRange indices) { @@ -1594,15 +1629,15 @@ struct ZHighToZLowDataConversionLowering int64_t rank = getRank(X.getType()); // SIMD info. - // Fixed VL for the conversion instruction: 8 elements per instruction call. - // Because the VL of the zlow.conversions are not "virtualized" in lengths, - // we manually unroll the loop containing the SIMD operations manually. + // Fixed VL for the conversion instruction: 8 elements per instruction + // call. Because the VL of the zlow.conversions are not "virtualized" in + // lengths, we manually unroll the loop containing the SIMD operations. // Experiments on a 1024x1024 tensors shows best results with an unrolling // of 8 SIMD vectors. - int64_t VL = 8; - int64_t VLHalf = VL / 2; - int64_t unrollSIMD = 8; // Manually unroll the SIMD loop. - int64_t unrollVL = unrollSIMD * VL; // Total numbers of values unrolled. + int64_t archVL = 8; // Vector length as defined by z arch for this type. + int64_t archVLHalf = archVL / 2; + int64_t unrollVL = 8; // Manually unroll the SIMD loop. + int64_t totVL = unrollVL * archVL; // Total numbers of values unrolled. // Convert the output type to MemRef. Type outputTensorType = convertOp.getResult().getType(); @@ -1612,24 +1647,23 @@ struct ZHighToZLowDataConversionLowering assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - // Types use the SIMD unrolling VL and VLHalf. + // Types use archVL and archVL / 2. Type f16Type = rewriter.getF16Type(); Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({VL}, f16Type); - VectorType vecF32Type = VectorType::get({VLHalf}, f32Type); + VectorType vecF16Type = VectorType::get({archVL}, f16Type); + VectorType vecF32Type = VectorType::get({archVLHalf}, f32Type); // Compute output dims. DimsExpr outputDims; ONNXUnaryOpShapeHelper shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); IndexExprScope allocScope(create.vec, shapeHelper.getScope()); - getIndexExprList(shapeHelper.getOutputDims(), outputDims); + getIndexExprList(shapeHelper.getOutputDims(), outputDims); - // Alloc memory with padding for SIMD. Padding and loop unrolling use - // unrollVL. + // Alloc memory with padding for SIMD using totVL. MemRefType outputMemRefType = mlir::cast(convertedType); Value alloc = create.mem.alignedAllocWithSimdPadding( - outputMemRefType, outputDims, unrollVL, alignment); + outputMemRefType, outputDims, totVL, alignment); // Flatten the input to 1D. int64_t collapsedInnermostLoops = rank; @@ -1642,18 +1676,18 @@ struct ZHighToZLowDataConversionLowering SmallVector flattenedOutputDims; Value flatOutput = create.mem.reshapeToFlatInnermost( alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops); - DimsExpr lbs(1, LiteralIndexExpr(0)); + DimsExpr lbs(1, LitIE(0)); - // Create loop iteration (flattened to 1D) and block it by unrollVL. + // Create loop iteration (flattened to 1D) and block it by totVL. 
ValueRange loopDef = create.krnl.defineLoops(1); - ValueRange blockedLoopDef = create.krnl.block(loopDef[0], unrollVL); + ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL); SmallVector optimizedLoopDef(1, blockedLoopDef[0]); if (enableParallel) { int64_t parId; int64_t tripCount = flattenedOutputDims[0].isLiteral() - ? std::ceil(flattenedOutputDims[0].getLiteral() / (float)VL) + ? std::ceil(flattenedOutputDims[0].getLiteral() / (float)archVL) : -1; if (findSuitableParallelDimension(lbs, flattenedOutputDims, 0, 1, parId, /*min iter for going parallel*/ 1024)) { @@ -1665,7 +1699,7 @@ struct ZHighToZLowDataConversionLowering "not enough work for dlf16-f32 conversion"); } } - onnxToKrnlSimdReport(op, /*successful*/ true, VL, + onnxToKrnlSimdReport(op, /*successful*/ true, archVL, flattenedOutputDims[0].isLiteral() ? flattenedOutputDims[0].getLiteral() : -1, "dlf16-f32 conversion fully flattened"); @@ -1673,26 +1707,26 @@ struct ZHighToZLowDataConversionLowering create.krnl.iterateIE(loopDef, optimizedLoopDef, lbs, flattenedOutputDims, [&](KrnlBuilder &b, ValueRange loopInd) { MDBuilder create(b); - // Manually unrolled loop, add VL offset at each iterations. - for (int64_t u = 0; u < unrollSIMD; ++u) { - Value baseIdx = - create.math.add(loopInd[0], create.math.constantIndex(u * VL)); + // Manually unrolled loop, add archVL offset at each iterations. + for (int64_t u = 0; u < unrollVL; ++u) { + Value baseIdx = create.math.add( + loopInd[0], create.math.constantIndex(u * archVL)); Value baseIdxNext = - create.math.add(baseIdx, create.math.constantIndex(VLHalf)); + create.math.add(baseIdx, create.math.constantIndex(archVLHalf)); if (fromF32) { // F32 -> DLF16 - // Load VL f32 values from the input into two vectors each - // with VLHalf f32 values. + // Load archVL f32 values from the input into two vectors each + // with archVLHalf f32 values. Value vecF32H = create.vec.load(vecF32Type, flatInput, {baseIdx}); Value vecF32L = create.vec.load(vecF32Type, flatInput, {baseIdxNext}); Value vecF16 = rewriter.create( loc, vecF32H, vecF32L); - // Store VL f16 values back to the output. + // Store archVL f16 values back to the output. create.vec.store(vecF16, flatOutput, {baseIdx}); } else { // DLF16 -> F32 - // Load VL f16 values from the input into a register. + // Load archVL f16 values from the input into a register. Value vecF16 = create.vec.load(vecF16Type, flatInput, {baseIdx}); auto convertOp = rewriter.create(loc, vecF16); diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp index 5529c132b5..021f47deb3 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp @@ -4,7 +4,7 @@ //====------ ZHighToZLow.hpp - ZHigh dialect to ZLow lowering -------------===// // -// Copyright 2019-2021 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZHIGH_TO_ZLOW_H +#define ONNX_MLIR_ZHIGH_TO_ZLOW_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -57,3 +58,4 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, } // namespace zhigh } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.hpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.hpp index be14fdfd31..445daa4fa8 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.hpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.hpp @@ -4,7 +4,7 @@ //===---------- ZLowToLLVM.hpp - Lowering from ZLow to LLVM ---------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZLOW_TO_LLVM_H +#define ONNX_MLIR_ZLOW_TO_LLVM_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/IR/PatternMatch.h" @@ -27,3 +28,4 @@ void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, } // namespace zlow } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index 8e4cff4574..e253389293 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -287,8 +287,8 @@ Value callApi(PatternRewriter &rewriter, Location loc, ModuleOp module, Type outputTy = apiSpec.outputTy; if (!mlir::isa(outputTy)) outputTys.emplace_back(outputTy); - return create.llvm.call( - ArrayRef(outputTys), symbolRef, ArrayRef(params)); + return create.llvm.call(ArrayRef(outputTys), symbolRef, + ArrayRef(params), apiSpec.isVarArg); } size_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy) { diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp index 6dc9b8c65e..9e9c251b73 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp @@ -4,7 +4,7 @@ //===---------- ZLowToLLVMCommon.hpp - Lowering from ZLow to LLVM ---------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZLOW_TO_LLVM_COMMON_H +#define ONNX_MLIR_ZLOW_TO_LLVM_COMMON_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -200,3 +201,4 @@ void fillInZTensor(mlir::PatternRewriter &rewriter, mlir::Location loc, } // namespace zlow } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp index 8c83285353..affe053c9a 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp @@ -4,7 +4,7 @@ //===------------------ ZHighOps.hpp - ZHigh Operations -------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZHIGH_H +#define ONNX_MLIR_ZHIGH_H #include #include @@ -58,3 +59,4 @@ class SameOperandsAndResultLayout #define GET_OP_CLASSES #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp.inc" +#endif \ No newline at end of file diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp index cc2479d53d..def0813d7b 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp @@ -4,13 +4,14 @@ //===-------- ZHighHelper.hpp - ZHigh Helper Functions --------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_OP_HELPER_H +#define ONNX_MLIR_OP_HELPER_H #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -89,3 +90,4 @@ mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter); } // namespace zhigh } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp index e4954f0d74..cb8194f408 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp @@ -4,7 +4,7 @@ //===----------------ShapeHelper.hpp - shape helpers for ZHigh ------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZHIGH_SHAPE_HELPER_H +#define ONNX_MLIR_ZHIGH_SHAPE_HELPER_H #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -174,3 +175,4 @@ using ZHighFixGRUYOpShapeHelper = ONNXUnaryOpShapeHelper; } // namespace zhigh } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp index bef9e3b3a6..b3310f0373 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp @@ -4,7 +4,7 @@ //====--------- DialectBuilder.hpp - ZLow Dialect Builder -----------------===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIALECT_BUILDER_H +#define ONNX_MLIR_DIALECT_BUILDER_H #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/Mlir/IndexExprBuilder.hpp" @@ -53,3 +54,4 @@ struct MultiDialectBuilder }; } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp index d7c9541543..541130b516 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp @@ -4,7 +4,7 @@ //===------------------ ZLowOps.hpp - ZLow Operations ---------------------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ZLOW_H +#define ONNX_MLIR_ZLOW_H #include #include @@ -31,3 +32,4 @@ #define GET_OP_CLASSES #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp.inc" +#endif diff --git a/src/Accelerators/NNPA/NNPAAccelerator.hpp b/src/Accelerators/NNPA/NNPAAccelerator.hpp index 5721beb94f..e40bd774b6 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.hpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.hpp @@ -4,7 +4,7 @@ //===-------------------------- NNPAAccelerator.hpp ----------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // =========================================================================== // @@ -12,7 +12,8 @@ // //===---------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_NNPA_ACCELERATOR_H +#define ONNX_MLIR_NNPA_ACCELERATOR_H #include "mlir/IR/BuiltinTypes.h" #include "src/Accelerators/Accelerator.hpp" @@ -79,3 +80,4 @@ class NNPAAccelerator final : public Accelerator { } // namespace accel } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index 7ffe8843fa..b15e7f165d 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -4,7 +4,7 @@ //===---------- NNPAPasses.hpp - NNPA Passes Definition ------------------===// // -// Copyright 2019-2022 The IBM Research Authors. 
+// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_NNPA_PASSES_H +#define ONNX_MLIR_NNPA_PASSES_H #include "mlir/Pass/Pass.h" @@ -72,3 +73,4 @@ std::unique_ptr createZLowDummyOpForMultiDerefPass(); } // namespace zlow } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c index 6986a4bea4..7f53cfd8b0 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c @@ -116,8 +116,6 @@ void checkStatus(zdnn_status status, const char *zdnn_name) { } } -#define CHECK_ZDNN_STATUS(status, zdnn_name) checkStatus(status, zdnn_name) - void getUnmappedShape(const zdnn_ztensor *t, UnmappedShape *shape) { const zdnn_tensor_desc *desc = t->transformed_desc; shape->e4 = desc->dim4; diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h index 70110501c5..3123dd2e76 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h @@ -58,7 +58,7 @@ extern bool OMStatusMessagesEnabled; // Misc Macros // ----------------------------------------------------------------------------- -#define CEIL(a, b) (uint64_t)((a + b - 1) / b) // positive numbers only +#define CEIL(a, b) (uint64_t)(((a) + (b)-1) / (b)) // positive numbers only // ----------------------------------------------------------------------------- // Common structures @@ -159,7 +159,7 @@ inline void omUnreachable() { */ void checkStatus(zdnn_status status, const char *zdnn_name); -#define CHECK_ZDNN_STATUS(status, zdnn_name) checkStatus(status, zdnn_name) +#define CHECK_ZDNN_STATUS(status, zdnn_name) checkStatus((status), (zdnn_name)) /** * \brief Get the unmapped shape (4D) of ztensor. diff --git a/src/Accelerators/NNPA/Support/LayoutHelper.hpp b/src/Accelerators/NNPA/Support/LayoutHelper.hpp index 7ed83c0dfc..fb512fc90a 100644 --- a/src/Accelerators/NNPA/Support/LayoutHelper.hpp +++ b/src/Accelerators/NNPA/Support/LayoutHelper.hpp @@ -4,13 +4,14 @@ //===---------- LayoutHelper.hpp - NNPA Layout Helper ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_LAYOUT_HELPER_H +#define ONNX_MLIR_LAYOUT_HELPER_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" @@ -45,3 +46,4 @@ bool isNHWCLayout(mlir::StringAttr layout); mlir::StringAttr getNCHWLayoutAttr(mlir::PatternRewriter &rewriter); } // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Support/NNPALimit.hpp b/src/Accelerators/NNPA/Support/NNPALimit.hpp index 44c5543128..fdf43a65e3 100644 --- a/src/Accelerators/NNPA/Support/NNPALimit.hpp +++ b/src/Accelerators/NNPA/Support/NNPALimit.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_NNPA_LIMIT_H +#define ONNX_MLIR_NNPA_LIMIT_H #include @@ -42,3 +43,4 @@ static constexpr const char *NNPA_Z16 = "z16"; // and (s=1,e=63,m=510) as the minimum value. static constexpr float DLF16_MAX = (1L << 32) * (1.0 + (510.0 / 512.0)); static constexpr float DLF16_MIN = -1 * (1L << 32) * (1.0 + (510.0 / 512.0)); +#endif diff --git a/src/Accelerators/NNPA/Support/Stickify/Convert.hpp b/src/Accelerators/NNPA/Support/Stickify/Convert.hpp index 5241e445be..8b3c9abce0 100644 --- a/src/Accelerators/NNPA/Support/Stickify/Convert.hpp +++ b/src/Accelerators/NNPA/Support/Stickify/Convert.hpp @@ -4,7 +4,7 @@ //===------- convert.hpp - Data Conversion --------------------------------===// // -// Copyright 2020-2022 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_CONVERT_H +#define ONNX_MLIR_CONVERT_H #include @@ -21,3 +22,4 @@ uint64_t fp32_to_dlf16( float *input_data, uint16_t *output_data, uint64_t nbr_fields_to_convert); uint64_t dlf16_to_fp32( uint16_t *input_data, float *output_data, uint64_t nbr_fields_to_convert); +#endif diff --git a/src/Accelerators/NNPA/Support/Stickify/DLF16Conversion.hpp b/src/Accelerators/NNPA/Support/Stickify/DLF16Conversion.hpp index 6dd1192867..c821187c8a 100644 --- a/src/Accelerators/NNPA/Support/Stickify/DLF16Conversion.hpp +++ b/src/Accelerators/NNPA/Support/Stickify/DLF16Conversion.hpp @@ -4,7 +4,7 @@ //===------- DLF16Conversion.hpp - DLF16 Conversion -----------------------===// // -// Copyright 2020-2022 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DLF16_H +#define ONNX_MLIR_DLF16_H #include #include #include @@ -173,3 +174,4 @@ inline void NNP1::convert(const float &fp, unsigned *vic) { *this = uint; } // NNP1::convert +#endif diff --git a/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp b/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp index 1f6c6261a9..d2ddc767b5 100644 --- a/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp +++ b/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp @@ -42,14 +42,14 @@ zdnn_status verify_transformed_descriptor(const zdnn_tensor_desc *tfrmd_desc); #define ZDNN_MAX_DIMS 4 // number of dims in AIU's Tensor Descriptor -#define CEIL(a, b) (uint64_t)((a + b - 1) / b) // positive numbers only -#define MIN(a, b) ((a > b) ? b : a) -#define MAX(a, b) ((a < b) ? b : a) +#define CEIL(a, b) (uint64_t)(((a) + (b)-1) / (b)) // positive numbers only +#define MIN(a, b) (((a) > (b)) ? (b) : (a)) +#define MAX(a, b) (((a) < (b)) ? (b) : (a)) #define BIT_SIZEOF(a) (sizeof(a) * 8) // padded = next multiple of AIU_2BYTE_CELLS_PER_STICK #define PADDED(x) \ - ((uint32_t)CEIL(x, AIU_2BYTE_CELLS_PER_STICK) * AIU_2BYTE_CELLS_PER_STICK) + ((uint32_t)CEIL((x), AIU_2BYTE_CELLS_PER_STICK) * AIU_2BYTE_CELLS_PER_STICK) #define ZDNN_STATUS_OK ZDNN_OK typedef enum elements_mode { @@ -92,8 +92,8 @@ DECLARE_DATA_FORMAT_STR(ZDNN_FORMAT_4DKERNEL) static short get_data_layout_num_gates(zdnn_data_layouts layout) { #define CASE_RTN_GATES(a, b) \ - case a: \ - return b; + case (a): \ + return (b); switch (layout) { CASE_RTN_GATES(ZDNN_BIDIR_ZRH, 3); @@ -109,8 +109,8 @@ static short get_data_layout_num_gates(zdnn_data_layouts layout) { static short get_data_layout_dims(zdnn_data_layouts layout) { #define CASE_RTN_DIM(a, b) \ - case a: \ - return b; + case (a): \ + return (b); switch (layout) { CASE_RTN_DIM(ZDNN_1D, 1); @@ -152,7 +152,7 @@ uint32_t get_rnn_concatenated_dim2(uint32_t val, zdnn_concat_info info) { short get_func_code_num_gates(nnpa_function_code func_code) { #define CASE_RTN_GATES(a, b) \ - case a: \ + case (a): \ return get_data_layout_num_gates(b); // piggyback thus no need to hardcode switch (func_code) { @@ -167,7 +167,7 @@ short get_func_code_num_gates(nnpa_function_code func_code) { const char *get_data_layout_str(zdnn_data_layouts layout) { #define CASE_RTN_STR(a) \ - case a: \ + case (a): \ return DATA_LAYOUT_STR_##a; switch (layout) { @@ -194,7 +194,7 @@ const char *get_data_layout_str(zdnn_data_layouts layout) { const char *get_data_format_str(zdnn_data_formats format) { #define CASE_RTN_STR(a) \ - case a: \ + case (a): \ return DATA_FORMAT_STR_##a; switch (format) { @@ -209,8 +209,8 @@ const char *get_data_format_str(zdnn_data_formats format) { short get_data_type_size(zdnn_data_types type) { #define CASE_RTN_SIZE(a, b) \ - case a: \ - return b; + case (a): \ + return (b); switch (type) { CASE_RTN_SIZE(BFLOAT, 2); diff --git a/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp b/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp index eeb7e18d2c..9bc1284f0c 100644 --- a/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp +++ b/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp @@ -4,7 +4,7 @@ //===------- stickify.hpp - Data Stickify ---------------------------------===// // -// Copyright 2020-2022 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. 
// // ============================================================================= // // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_STICKIFY_H +#define ONNX_MLIR_STICKIFY_H #include "zdnn.h" #include "llvm/ADT/ArrayRef.h" @@ -65,3 +66,4 @@ void allochelper_ztensor_free(zdnn_ztensor *ztensor); /// ZDNN_CONVERT_FAILURE /// zdnn_status stickify(zdnn_ztensor *ztensor, ...); +#endif diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index 59a57128a7..4db736a82d 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -67,17 +67,23 @@ class UnstickExpansionPattern : public OpRewritePattern { ZLowUnstickOp unstickOp, PatternRewriter &rewriter) const override { // Generic way to handle all formats listed below. + // Did not add the HWCK layout as it is typically used for constants, and we + // want to preserve the high-level constant propagation of constant values + // into the Convolution filters. StringAttr layout = unstickOp.getLayoutAttr(); if (layout.getValue().equals_insensitive("4D") || layout.getValue().equals_insensitive("3D") || layout.getValue().equals_insensitive("2D") || - layout.getValue().equals_insensitive("3DS")) { + layout.getValue().equals_insensitive("3DS") || + (layout.getValue().equals_insensitive("NHWC"))) { return generateUnstickCodeNoBuffer(rewriter, unstickOp); } // Otherwise, we don't replace and keep the zdnn call. return failure(); } + // The only requirement for this code to generate correct code is that E1 + // has been sticked by 64. LogicalResult generateUnstickCodeNoBuffer( PatternRewriter &rewriter, ZLowUnstickOp unstickOp) const { Operation *op = unstickOp.getOperation(); @@ -93,20 +99,20 @@ class UnstickExpansionPattern : public OpRewritePattern { int64_t rank = outputDims.size(); // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); + int64_t archVL = 8; // FP16 archVL. + int64_t archVLHalf = archVL / 2; // FP32 archVL. + assert(64 % archVL == 0 && "SIMD vector length must divide 64"); Type f16Type = rewriter.getF16Type(); Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({VL}, f16Type); - MemRefType bufferType = MemRefType::get({VL}, f32Type); + VectorType vecF16Type = VectorType::get({archVL}, f16Type); + MemRefType bufferType = MemRefType::get({archVL}, f32Type); // Define useful literals. - IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr litVL = LiteralIndexExpr(VL); - IndexExpr lit64 = LiteralIndexExpr(64); + IndexExpr litZero = LitIE(0); + IndexExpr lit1 = LitIE(1); + IndexExpr litArchVLHalf = LitIE(archVLHalf); + IndexExpr litArchVL = LitIE(archVL); + IndexExpr lit64 = LitIE(64); // Useful references for indexing dimensions (neg val are not used). int64_t E1 = rank - 1; @@ -137,7 +143,7 @@ class UnstickExpansionPattern : public OpRewritePattern { // tiles. Since we don't allocate, it is just a "view", we only need to // index by the "tile size", it is sufficient to assume 2 or more. Tiles are // 64.
- IndexExpr T = LiteralIndexExpr(2); + IndexExpr T = LitIE(2); DimsExpr reallocTileDims = {T, lit64}; Value inputAsTx64 = create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims); @@ -155,7 +161,7 @@ class UnstickExpansionPattern : public OpRewritePattern { // Translate the tile index t1 to the actual targetted data. Value inputOffset = create.krnl.getLinearOffsetIndexIE(input, inputAF); - IndexExpr inputDataOffset = SymbolIndexExpr(inputOffset); + IndexExpr inputDataOffset = SymIE(inputOffset); IndexExpr inputTileOffset = inputDataOffset.floorDiv(64); // Prefetch @@ -179,7 +185,7 @@ class UnstickExpansionPattern : public OpRewritePattern { // I may process here up to [e1 ... e1 + m*64), make sure its // not going out of bound, i.e. beyond outputDIms[E1]; IndexExpr ub1 = SymIE(outputDims[E1]); - IndexExpr lit64Bis = LiteralIndexExpr(64); + IndexExpr lit64Bis = LitIE(64); IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1); IndexExpr isFullLogical = isFull >= 0; create.scf.ifThenElse( @@ -189,21 +195,23 @@ class UnstickExpansionPattern : public OpRewritePattern { [&](SCFBuilder b) { MDBuilder create(b); // Loop (tried unroll of 2 and 8, 4 was best). - const int64_t U = 4; - assert(U * VL <= 64 && "bad unroll"); - create.scf.forLoop(litZero.getValue(), lit64.getValue(), U * VL, + const int64_t unrollVL = 4; + const int64_t totVL = unrollVL * archVL; + assert(totVL <= 64 && "bad unroll"); + create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL, [&](SCFBuilder b, Value loopIndex) { MDBuilder create(b); IndexExprScope innerScope(b, &outerScope); IndexExpr l = DimIE(loopIndex); - Value vecF16[U], vecF32H[U], vecF32L[U]; + Value vecF16[unrollVL], vecF32H[unrollVL], + vecF32L[unrollVL]; // Load f16 values from input via reinterpreted data tile. - for (int64_t i = 0; i < U; ++i) { + for (int64_t i = 0; i < unrollVL; ++i) { vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64, - {SymIE(inputTileOffset), l + (i * VL)}, {}); + {SymIE(inputTileOffset), l + (i * archVL)}, {}); } // Convert back to f32. - for (int64_t i = 0; i < U; ++i) { + for (int64_t i = 0; i < unrollVL; ++i) { auto convertOp = rewriter.create( loc, vecF16[i]); @@ -213,8 +221,8 @@ class UnstickExpansionPattern : public OpRewritePattern { // Store f32 values back to the (normal layout) output. DimsExpr outputAF = SymListIE(inputAF); outputAF[E1] = outputAF[E1] + l; - for (int64_t i = 0; i < U; ++i) { - LiteralIndexExpr iH(i * VL), iL(i * VL + VL / 2); + for (int64_t i = 0; i < unrollVL; ++i) { + LitIE iH(i * archVL), iL(i * archVL + archVL / 2); create.vec.storeIE( vecF32H[i], alloc, outputAF, {iH.getValue()}); create.vec.storeIE( @@ -231,9 +239,10 @@ class UnstickExpansionPattern : public OpRewritePattern { // all as we subtract (VL-1). Aka if VL=8 and tripCount = 16, // tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate // over i=0 & i=8 as both are < 9. 
- IndexExpr tripCountWithoutPartialLastVL = tripCount - (VL - 1); + IndexExpr tripCountWithoutPartialLastVL = + tripCount - (archVL - 1); create.scf.forLoop(litZero.getValue(), - tripCountWithoutPartialLastVL.getValue(), VL, + tripCountWithoutPartialLastVL.getValue(), archVL, [&](SCFBuilder b, Value loopIndex) { MDBuilder create(b); IndexExprScope innerScope(b, &middleScope); @@ -252,10 +261,10 @@ class UnstickExpansionPattern : public OpRewritePattern { outputAF[E1] = outputAF[E1] + l; create.vec.storeIE(vecF32H, alloc, outputAF, {}); create.vec.storeIE( - vecF32L, alloc, outputAF, {litVLHalf.getValue()}); + vecF32L, alloc, outputAF, {litArchVLHalf.getValue()}); }); // Deal with the last values: compute f32 using simd. - IndexExpr remainingScalarValues = tripCount % VL; + IndexExpr remainingScalarValues = tripCount % archVL; IndexExpr lastL = tripCount - remainingScalarValues; Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64, {SymIE(inputTileOffset), lastL}, {}); @@ -264,10 +273,10 @@ class UnstickExpansionPattern : public OpRewritePattern { rewriter.create(loc, vecF16); Value vecF32H = convertOp.getResult(0); Value vecF32L = convertOp.getResult(1); - // Save into VL value buffer. + // Save into archVL value buffer. Value bufferF32 = create.mem.alignedAlloca(bufferType); create.vec.storeIE(vecF32H, bufferF32, {litZero}, {}); - create.vec.storeIE(vecF32L, bufferF32, {litVLHalf}, {}); + create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {}); // Save the remaining values as scalars. create.scf.forLoop(litZero.getValue(), remainingScalarValues.getValue(), 1, @@ -306,17 +315,23 @@ class StickExpansionPattern : public OpRewritePattern { StringAttr layout = stickOp.getLayoutAttr(); // Generic way to handle all formats listed below. + // Did not add the HWCK layout as it is typically used for constants, and we + // want to preserve the high-level constant propagation of constant values + // into the Convolution filters. if (layout.getValue().equals_insensitive("4D") || layout.getValue().equals_insensitive("3D") || layout.getValue().equals_insensitive("2D") || - layout.getValue().equals_insensitive("3DS")) { + layout.getValue().equals_insensitive("3DS") || + layout.getValue().equals_insensitive("NHWC")) { return generateStickCodeNoBuffer(rewriter, stickOp); } // Otherwise, we don't replace and keep the zdnn call. return failure(); } - /* Version without buffer, more like zdnn */ + // Version without buffer, more like zdnn. + // The only requirement for this code to generate correct code is that E1 + // has been sticked by 64. LogicalResult generateStickCodeNoBuffer( PatternRewriter &rewriter, ZLowStickOp stickOp) const { Operation *op = stickOp.getOperation(); @@ -335,17 +350,16 @@ class StickExpansionPattern : public OpRewritePattern { int64_t rank = outputDims.size(); // Info for SIMD Vector Length (VL) and associated types. - int64_t VL = 8; // FP16 VL. - int64_t VLHalf = VL / 2; // FP32 VL. - assert(64 % VL == 0 && "SIMD vector length must divide 64"); + int64_t archVL = 8; // FP16 archVL. + int64_t archVLHalf = archVL / 2; // FP32 archVL. + assert(64 % archVL == 0 && "SIMD vector length must divide 64"); Type f32Type = rewriter.getF32Type(); - VectorType vecF32Type = VectorType::get({VLHalf}, f32Type); + VectorType vecF32Type = VectorType::get({archVLHalf}, f32Type); // Define useful literals.
- IndexExpr litZero = LiteralIndexExpr(0); - IndexExpr lit1 = LiteralIndexExpr(1); - IndexExpr litVLHalf = LiteralIndexExpr(VLHalf); - IndexExpr lit64 = LiteralIndexExpr(64); + IndexExpr litZero = LitIE(0); + IndexExpr lit1 = LitIE(1); + IndexExpr lit64 = LitIE(64); // Values for saturation. Value vecDlf16Min, vecDlf16Max; @@ -367,6 +381,28 @@ class StickExpansionPattern : public OpRewritePattern { IndexExpr T1 = outputDims[E1].ceilDiv(64); ubs[E1] = T1; // E1 dim is over tiles. + // If outputDims[E1] is constant and < 64, then T1 is 1 (ok), and we can + // iterate over fewer values in the SIMD loop. + IndexExpr simdLoopUB = lit64; + int64_t unrollVL = 4; // Unrolling of SIMD loop: tried 2 and 8, 4 was best. + if (outputDims[E1].isLiteral()) { + int64_t d1 = outputDims[E1].getLiteral(); + if (d1 < 64) { + // Shrink unrollVL if suitable. + if (d1 <= archVL) + unrollVL = 1; + else if (d1 <= 2 * archVL) + unrollVL = 2; + else if (d1 <= 3 * archVL) + unrollVL = 3; + double trip = unrollVL * archVL; + int64_t ub = std::ceil((1.0 * d1) / trip) * trip; + simdLoopUB = LitIE(ub); + } + } + int64_t totVL = unrollVL * archVL; + assert(totVL <= 64 && "bad unroll"); + // Parallel... if (enableParallel) { int64_t parId; @@ -385,7 +421,7 @@ class StickExpansionPattern : public OpRewritePattern { // tiles. Since we don't allocate, it is just a "view", we only need to // index by the "tile size", it is sufficient to assume 2 or more. Tiles are // 64 elements. - IndexExpr T = LiteralIndexExpr(2); + IndexExpr T = LitIE(2); DimsExpr reallocTileDims = {T, lit64}; Value allocAsTx64 = create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims); @@ -396,7 +432,7 @@ class StickExpansionPattern : public OpRewritePattern { MDBuilder create(b); IndexExprScope outerScope(create.krnl, &allocScope); DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); + getIndexExprList(loopInd, outerIndices); DimsExpr memAF = outerIndices; memAF[E1] = memAF[E1] * 64; // Loop index for E1 is in tiles of 64. Value allocOffset = create.krnl.getLinearOffsetIndexIE(alloc, memAF); @@ -418,21 +454,19 @@ class StickExpansionPattern : public OpRewritePattern { #endif #endif - const int64_t U = 4; // Tried 2 and 8, 4 was best. - assert(U * VL <= 64 && "bad unroll"); - create.affine.forIE(litZero, lit64, U * VL, + create.affine.forIE(litZero, simdLoopUB, totVL, [&](AffineBuilder &b, ValueRange loopInd) { MDBuilder create(b); DimsExpr inputAF; IndexExprScope innerScope(create.krnl, &outerScope); - SymbolIndexExpr l(loopInd[0]); - getIndexExprList(memAF, inputAF); + SymIE l(loopInd[0]); + getIndexExprList(memAF, inputAF); // E1: add the "l" local E1 offset. inputAF[E1] = inputAF[E1] + l; // Load the f32. - Value vecF32H[U], vecF32L[U], vecF16[U]; - for (int64_t u = 0; u < U; ++u) { - LiteralIndexExpr iH(u * VL), iL(u * VL + VL / 2); + Value vecF32H[unrollVL], vecF32L[unrollVL], vecF16[unrollVL]; + for (int64_t u = 0; u < unrollVL; ++u) { + LitIE iH(u * archVL), iL(u * archVL + archVL / 2); vecF32H[u] = create.vec.loadIE( vecF32Type, input, inputAF, {iH.getValue()}); vecF32L[u] = create.vec.loadIE( @@ -440,25 +474,25 @@ class StickExpansionPattern : public OpRewritePattern { } if (saturation) { // Get rid of too-high values. - for (int64_t u = 0; u < U; ++u) { + for (int64_t u = 0; u < unrollVL; ++u) { vecF32H[u] = create.math.min(vecF32H[u], vecDlf16Max); vecF32L[u] = create.math.min(vecF32L[u], vecDlf16Max); } // Get rid of too-low values. 
- for (int64_t u = 0; u < U; ++u) { + for (int64_t u = 0; u < unrollVL; ++u) { vecF32H[u] = create.math.max(vecF32H[u], vecDlf16Min); vecF32L[u] = create.math.max(vecF32L[u], vecDlf16Min); } } // Convert f32 to dlfloat16. - for (int64_t u = 0; u < U; ++u) { + for (int64_t u = 0; u < unrollVL; ++u) { vecF16[u] = rewriter.create( loc, vecF32H[u], vecF32L[u]); } // Store the dlfloat16. - for (int64_t u = 0; u < U; ++u) { + for (int64_t u = 0; u < unrollVL; ++u) { create.vec.storeIE(vecF16[u], allocAsTx64, - {SymIE(allocTileIndex), l + (u * VL)}, {}); + {SymIE(allocTileIndex), l + (u * archVL)}, {}); } }); }); diff --git a/src/Builder/FrontendDialectHelper.hpp b/src/Builder/FrontendDialectHelper.hpp index eb3de87a79..d2d219a390 100644 --- a/src/Builder/FrontendDialectHelper.hpp +++ b/src/Builder/FrontendDialectHelper.hpp @@ -4,7 +4,7 @@ //===--------------------- FrontendDialectHelper.hpp ----------------------===// // -// Copyright 2019 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_FRONTEND_HELPER_H +#define ONNX_MLIR_FRONTEND_HELPER_H #include "mlir/IR/BuiltinAttributeInterfaces.h" @@ -26,3 +27,4 @@ mlir::ElementsAttr onnxTensorProtoToElmAttr(mlir::MLIRContext *ctx, const std::string &externalDataDir, const onnx::TensorProto &initializer); } // namespace onnx_mlir +#endif diff --git a/src/Builder/FrontendDialectTransformer.hpp b/src/Builder/FrontendDialectTransformer.hpp index 36283abbc6..bfac88ba4c 100644 --- a/src/Builder/FrontendDialectTransformer.hpp +++ b/src/Builder/FrontendDialectTransformer.hpp @@ -4,13 +4,14 @@ //===--------- FrontendDialectTransformer.hpp - MLIR Operations -----------===// // -// Copyright 2019 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_FRONTEND_TRANSFORMER_H +#define ONNX_MLIR_FRONTEND_TRANSFORMER_H #include #include @@ -109,3 +110,4 @@ void ImportFrontendModel(const onnx::ModelProto &model, * operations specific to other frameworks such as Tensorflow or Pytorch. */ } // namespace onnx_mlir +#endif diff --git a/src/Builder/ImportONNXUtils.hpp b/src/Builder/ImportONNXUtils.hpp index 5f274b9fcb..d1c1e23aa3 100644 --- a/src/Builder/ImportONNXUtils.hpp +++ b/src/Builder/ImportONNXUtils.hpp @@ -4,7 +4,7 @@ //===--------------------- ImportONNXUtils.hpp ----------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. 
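The #pragma once replacements above all follow one pattern: a per-file ONNX_MLIR_..._H guard opened right after the license header and closed with a bare #endif at the end of the file. A minimal skeleton of that shape (hypothetical file and guard name, not one of the headers touched here):

```cpp
// Shape shared by the converted headers: include guard instead of #pragma once.
#ifndef ONNX_MLIR_EXAMPLE_H
#define ONNX_MLIR_EXAMPLE_H

#include <string>

namespace onnx_mlir {
// ... declarations ...
std::string exampleDeclaration();
} // namespace onnx_mlir
#endif
```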
// // ============================================================================= // @@ -12,10 +12,12 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_IMPORT_UTILS_H +#define ONNX_MLIR_IMPORT_UTILS_H #include "onnx/onnx_pb.h" bool IsTopologicallySorted(const onnx::GraphProto &graph); bool SortGraph(onnx::GraphProto *graph); +#endif diff --git a/src/Builder/ModelInputShaper.hpp b/src/Builder/ModelInputShaper.hpp index a6dcbccb18..6af7724429 100644 --- a/src/Builder/ModelInputShaper.hpp +++ b/src/Builder/ModelInputShaper.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_MODEL_INPUT_H +#define ONNX_MLIR_MODEL_INPUT_H #include #include @@ -88,3 +89,4 @@ class ModelInputShaper { }; } // namespace onnx_mlir +#endif diff --git a/src/Builder/SymbolTable.hpp b/src/Builder/SymbolTable.hpp index 0abf07e69a..b87f149471 100644 --- a/src/Builder/SymbolTable.hpp +++ b/src/Builder/SymbolTable.hpp @@ -2,7 +2,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#pragma once +#ifndef ONNX_MLIR_SYMBOL_TABLE_H +#define ONNX_MLIR_SYMBOL_TABLE_H #include #include @@ -164,3 +165,4 @@ bool VariableScope::contains(const std::string &name) const { } } // namespace onnx_mlir +#endif diff --git a/src/Compiler/CompilerDialects.hpp b/src/Compiler/CompilerDialects.hpp index f61303640e..b596f51442 100644 --- a/src/Compiler/CompilerDialects.hpp +++ b/src/Compiler/CompilerDialects.hpp @@ -4,7 +4,8 @@ //===------------------------ CompilerDialects.hpp ------------------------===// -#pragma once +#ifndef ONNX_MLIR_COMPILER_DIALECTS_H +#define ONNX_MLIR_COMPILER_DIALECTS_H #include "src/Accelerators/Accelerator.hpp" @@ -19,3 +20,4 @@ mlir::DialectRegistry registerDialects( llvm::ArrayRef accels); } // namespace onnx_mlir +#endif diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index fbce585571..6d010bd219 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -64,7 +64,7 @@ std::string mllvm; // onnx-mlir only std::string instrumentOps; // onnx-mlir only unsigned instrumentControlBits; // onnx-mlir only std::string parallelizeOps; // onnx-mlir only -bool instrumentONNXSignature; // onnx-mlir only +std::string instrumentSignatures; // onnx-mlir only std::string ONNXOpStats; // onnx-mlir only int onnxOpTransformThreshold; // onnx-mlir only bool onnxOpTransformReport; // onnx-mlir only @@ -432,10 +432,17 @@ static llvm::cl::opt parallelizeOpsOpt("parallelize-ops", llvm::cl::location(parallelizeOps), llvm::cl::init(""), llvm::cl::cat(OnnxMlirOptions)); -static llvm::cl::opt instrumentONNXSignatureOpt( - "instrument-onnx-signature", - llvm::cl::desc("Instrument ONNX ops to print the type of their inputs"), - llvm::cl::location(instrumentONNXSignature), llvm::cl::init(false), +static llvm::cl::opt instrumentSignatureOpt( + "instrument-signature", + llvm::cl::desc("Specify which high-level operations should print their" + " input type(s) and shape(s)\n" + "\"ALL\" or \"\" for all available operations,\n" + "\"NONE\" for no instrument (default),\n" + "\"ops1,ops2, ...\" for the multiple ops.\n" + "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n" + "Asterisk is also available.\n" + "e.g. 
\"onnx.*\" for all onnx operations.\n"), + llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt ONNXOpStatsOpt("onnx-op-stats", @@ -593,6 +600,25 @@ static llvm::cl::opt enable_bound_check("enable-bound-check", llvm::cl::location(enableBoundCheck), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +#if defined(_DEBUG) +// Option only available in debug mode: set using command options. +static llvm::cl::opt test_compiler_opt("test-compiler-opt", + llvm::cl::desc( + "Help compiler writers test a new (small) optimization. When false, " + "the old approach should be used. When true, the new opt should be " + "used. Utilities such as CheckONNXModel.py can then verify that the " + "new opt deliver the same results.\n" + "E.g. CheckONNXModel.py -m test.mlir -t -O3 -a test-compiler-opt=true\n" + "Once the new opt works, it should not rely this option any more.\n" + "Only defined in DEBUG build and default to false.\n"), + llvm::cl::location(debugTestCompilerOpt), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); +bool debugTestCompilerOpt; +#else +// Option only available in debug mode: disable when not in debug. +bool debugTestCompilerOpt = false; +#endif + // Options for onnx-mlir-opt only static llvm::cl::opt split_input_file_opt("split-input-file", llvm::cl::desc("Split the input file into pieces and process each " diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 88dbdbcbad..2ed9f251e1 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -4,7 +4,7 @@ //===------------------------ CompilerOptions.hpp -------------------------===// // -// Copyright 2022, 2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_COMPILER_OPTIONS_H +#define ONNX_MLIR_COMPILER_OPTIONS_H #include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "src/Accelerators/Accelerator.hpp" #include "llvm/Support/CommandLine.h" @@ -108,7 +109,7 @@ extern std::string mllvm; // onnx-mlir only extern std::string instrumentOps; // onnx-mlir only extern unsigned instrumentControlBits; // onnx-mlir only extern std::string parallelizeOps; // onnx-mlir only -extern bool instrumentONNXSignature; // onnx-mlir only +extern std::string instrumentSignatures; // onnx-mlir only extern std::string ONNXOpStats; // onnx-mlir only extern int onnxOpTransformThreshold; // onnx-mlir only extern bool onnxOpTransformReport; // onnx-mlir only @@ -130,6 +131,8 @@ extern OptReport optReport; // onnx-mlir only extern bool useOldBufferization; // onnx-mlir only extern bool enableTiming; // onnx-mlir only extern bool enableBoundCheck; // onnx-mlir only +extern bool debugTestCompilerOpt; // onnx-mlir only + extern bool split_input_file; // onnx-mlir-opt only extern bool verify_diagnostics; // onnx-mlir-opt only extern bool verify_passes; // onnx-mlir-opt only @@ -211,3 +214,4 @@ void removeUnrelatedOptions( void initCompilerConfig(); } // namespace onnx_mlir +#endif diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index a4d303522b..c8f84e5565 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -157,7 +157,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { } void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, - bool enableInstrumentONNXSignature, std::string ONNXOpsStatFormat) { + std::string instrumentSignatureString, std::string ONNXOpsStatFormat) { if (enableCSE) // Eliminate common sub-expressions before lowering to Krnl. // TODO: enable this by default when we make sure it works flawlessly. @@ -182,10 +182,11 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, } // Print Signatures of each op at runtime if enabled. Should not run - // signature and instrument passes at the same time. - if (enableInstrumentONNXSignature) - pm.addNestedPass( - onnx_mlir::createInstrumentONNXSignaturePass()); + // signature and instrument passes at the same time as time may include printf + // overheads. + if (instrumentSignatureString != "NONE") + pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( + instrumentSignatureString)); pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, /*opsToCall*/ opsForCall)); @@ -304,7 +305,7 @@ void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, if (emissionTarget >= EmitMLIR) { if (inputIRLevel <= ONNXLevel) addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true, - instrumentONNXSignature, ONNXOpStats); + instrumentSignatures, ONNXOpStats); if (inputIRLevel <= MLIRLevel) addKrnlToAffinePasses(pm); } diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp index aaf2115b8c..f0c0499f8f 100644 --- a/src/Compiler/CompilerPasses.hpp +++ b/src/Compiler/CompilerPasses.hpp @@ -4,7 +4,7 @@ //===------------------------- CompilerPasses.hpp -------------------------===// // -// Copyright 2022 The IBM Research Authors. 
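The instrumentation pass is now added only when the --instrument-signature string is not "NONE", and the string itself (comma-separated op names, with a trailing asterisk as wildcard) is handed to the pass. A rough sketch of how such a value could be matched against an op name; the helper below is hypothetical, plain C++, and the real pass has its own matching logic:

```cpp
#include <cstdio>
#include <sstream>
#include <string>

// Returns true if opName matches the --instrument-signature style value:
// "NONE" -> nothing, "" or "ALL" -> everything, otherwise a comma-separated
// list where a trailing '*' acts as a prefix wildcard (e.g. "onnx.*").
static bool matchesSignatureList(
    const std::string &opName, const std::string &option) {
  if (option == "NONE")
    return false;
  if (option.empty() || option == "ALL")
    return true;
  std::stringstream ss(option);
  std::string pattern;
  while (std::getline(ss, pattern, ',')) {
    if (!pattern.empty() && pattern.back() == '*') {
      std::string prefix = pattern.substr(0, pattern.size() - 1);
      if (opName.compare(0, prefix.size(), prefix) == 0)
        return true;
    } else if (opName == pattern) {
      return true;
    }
  }
  return false;
}

int main() {
  printf("%d\n", matchesSignatureList("onnx.MatMul", "onnx.MatMul,onnx.Add")); // 1
  printf("%d\n", matchesSignatureList("onnx.Relu", "onnx.*"));                 // 1
  printf("%d\n", matchesSignatureList("onnx.Relu", "NONE"));                   // 0
  return 0;
}
```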
+// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_COMPILER_PASSES_H +#define ONNX_MLIR_COMPILER_PASSES_H #include "mlir/Pass/PassManager.h" namespace onnx_mlir { @@ -21,7 +22,7 @@ void configurePasses(); void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU); void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, - bool enableInstrumentONNXSignature, std::string ONNXOpsStatFilename); + std::string instrumentSignatureString, std::string ONNXOpsStatFilename); void addKrnlToAffinePasses(mlir::PassManager &pm); void addKrnlToLLVMPasses( mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE); @@ -30,3 +31,4 @@ InputIRLevelType determineInputIRLevel( void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, EmissionTargetType emissionTarget, std::string outputNameNoExt); } // namespace onnx_mlir +#endif diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 476edc81e7..67c53ce1b3 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -62,7 +62,7 @@ namespace onnx_mlir { // Values to report the current phase of compilation. // Increase TOTAL_COMPILE_PHASE when having more phases. uint64_t CURRENT_COMPILE_PHASE = 1; -uint64_t TOTAL_COMPILE_PHASE = 5; +uint64_t TOTAL_COMPILE_PHASE = 6; // Make a function that forces preserving all files using the runtime arguments // and/or the overridePreserveFiles enum. @@ -170,18 +170,37 @@ int Command::exec(std::string wdir) const { } void showCompilePhase(std::string msg) { - time_t rawtime; - struct tm *timeinfo; + time_t rawTime; + struct tm *timeInfo; char buffer[80]; + // Remember first time. + static time_t firstRawTime; + static bool hasFirstRawTime = false; // Get current date. - time(&rawtime); - timeinfo = localtime(&rawtime); - strftime(buffer, 80, "%c", timeinfo); + time(&rawTime); + timeInfo = localtime(&rawTime); + strftime(buffer, 80, "%c", timeInfo); std::string currentTime(buffer); + // Compute time difference in seconds. + int diff = 0; + if (hasFirstRawTime) { + diff = difftime(rawTime, firstRawTime); + } else { + firstRawTime = rawTime; + hasFirstRawTime = true; + } llvm::outs() << "[" << CURRENT_COMPILE_PHASE++ << "/" << TOTAL_COMPILE_PHASE - << "] " << currentTime << " " << msg << "\n"; + << "] " << currentTime << " (" << diff << "s) " << msg << "\n"; + // Flush so that if there are errors, we know where it came from. + llvm::outs().flush(); + + // Reset current phase. + if (CURRENT_COMPILE_PHASE > TOTAL_COMPILE_PHASE) { + CURRENT_COMPILE_PHASE = 1; + hasFirstRawTime = false; + } } } // namespace onnx_mlir @@ -807,6 +826,8 @@ static int emitOutputFiles(std::string outputNameNoExt, } } } + showCompilePhase("Compilation completed"); + return CompilerSuccess; } // end anonymous namespace @@ -923,6 +944,10 @@ int compileModule(mlir::OwningOpRef &module, mlir::MLIRContext &context, std::string outputNameNoExt, EmissionTargetType emissionTarget) { std::string msg = "Compiling and Optimizing MLIR Module"; + // There is no importing phase (e.g. the model is .mlir, not .onnx), adjust to + // correctly reflect the current phase. 
+ if (CURRENT_COMPILE_PHASE == 1) + CURRENT_COMPILE_PHASE++; showCompilePhase(msg); auto compileModuleTiming = rootTimingScope.nest("[onnx-mlir] " + msg); diff --git a/src/Compiler/CompilerUtils.hpp b/src/Compiler/CompilerUtils.hpp index 5bdb5db87b..e3ecc1bd72 100644 --- a/src/Compiler/CompilerUtils.hpp +++ b/src/Compiler/CompilerUtils.hpp @@ -4,7 +4,7 @@ //===-------------------------- CompilerUtils.hpp -------------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_COMPILER_UTILS_H +#define ONNX_MLIR_COMPILER_UTILS_H #include "onnx-mlir/Compiler/OMCompilerTypes.h" @@ -91,3 +92,4 @@ std::string getTargetFilename( const std::string filenameNoExt, EmissionTargetType target); } // namespace onnx_mlir +#endif diff --git a/src/Compiler/DisposableGarbageCollector.hpp b/src/Compiler/DisposableGarbageCollector.hpp index daa54d5cdb..711402ccb2 100644 --- a/src/Compiler/DisposableGarbageCollector.hpp +++ b/src/Compiler/DisposableGarbageCollector.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_GARBAGE_COLLECTOR_H +#define ONNX_MLIR_GARBAGE_COLLECTOR_H #include "mlir/Pass/PassInstrumentation.h" @@ -31,3 +32,4 @@ struct DisposableGarbageCollector : public mlir::PassInstrumentation { }; } // namespace onnx_mlir +#endif diff --git a/src/Compiler/ExternalUtil.hpp.in b/src/Compiler/ExternalUtil.hpp.in index dbb266140f..fc599ece69 100644 --- a/src/Compiler/ExternalUtil.hpp.in +++ b/src/Compiler/ExternalUtil.hpp.in @@ -1,4 +1,5 @@ -#pragma once +#ifndef ONNX_MLIR_EXTERNAL_UTIL_H +#define ONNX_MLIR_EXTERNAL_UTIL_H #include #include @@ -26,3 +27,4 @@ static const std::map toolPathMap = { {"jar", kJarPath}, {"defaultTriple", kDefaultTriple}, {"lrodataScript", kLrodataScript}}; } // namespace onnx_mlir +#endif diff --git a/src/Compiler/HeapReporter.hpp b/src/Compiler/HeapReporter.hpp index 0673195a5d..3082cfcbe1 100644 --- a/src/Compiler/HeapReporter.hpp +++ b/src/Compiler/HeapReporter.hpp @@ -10,7 +10,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_HEAP_REPORTER_H +#define ONNX_MLIR_HEAP_REPORTER_H #include #include @@ -39,3 +40,4 @@ struct HeapReporter : public mlir::PassInstrumentation { }; } // namespace onnx_mlir +#endif diff --git a/src/Compiler/OptionUtils.hpp b/src/Compiler/OptionUtils.hpp index bea68f7c98..535be087bb 100644 --- a/src/Compiler/OptionUtils.hpp +++ b/src/Compiler/OptionUtils.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_OPTION_UTILS_H +#define ONNX_MLIR_OPTION_UTILS_H #include "onnx-mlir/Compiler/OMCompilerTypes.h" @@ -61,3 +62,4 @@ class EnableByRegexOption { }; } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlSeqToMemref/ConvertSeqToMemref.hpp b/src/Conversion/KrnlSeqToMemref/ConvertSeqToMemref.hpp index 4761a5c411..5397a4d12d 100644 --- a/src/Conversion/KrnlSeqToMemref/ConvertSeqToMemref.hpp +++ b/src/Conversion/KrnlSeqToMemref/ConvertSeqToMemref.hpp @@ -4,7 +4,7 @@ //====------ ConvertSeqToMemrefM.hpp - Krnl Dialect Lowering //---------------===// // -// Copyright 2019-2022 The IBM Research Authors. 
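showCompilePhase now records the time of the first reported phase, prints the seconds elapsed for every later phase, flushes the stream, and resets its counters after the final phase (skipping the import phase when the input is already MLIR). A standalone sketch of that reporting pattern with illustrative names:

```cpp
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <string>

static uint64_t currentPhase = 1;
static const uint64_t totalPhases = 6;

// Prints "[phase/total] <date> (<elapsed>s) <msg>" like showCompilePhase above.
static void reportPhase(const std::string &msg) {
  static time_t firstTime;
  static bool hasFirstTime = false;
  time_t now;
  time(&now);
  char buffer[80];
  strftime(buffer, sizeof(buffer), "%c", localtime(&now));
  int diff = 0;
  if (hasFirstTime) {
    diff = (int)difftime(now, firstTime);
  } else {
    firstTime = now;
    hasFirstTime = true;
  }
  printf("[%llu/%llu] %s (%ds) %s\n", (unsigned long long)currentPhase++,
      (unsigned long long)totalPhases, buffer, diff, msg.c_str());
  fflush(stdout); // flush so errors can be located relative to the last phase
  if (currentPhase > totalPhases) { // reset for a potential next compilation
    currentPhase = 1;
    hasFirstTime = false;
  }
}

int main() {
  reportPhase("Importing ONNX Model to MLIR Module");
  reportPhase("Compiling and Optimizing MLIR Module");
  return 0;
}
```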
+// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_CONVERT_SEQ_H +#define ONNX_MLIR_CONVERT_SEQ_H #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Pass/Passes.hpp" @@ -35,3 +36,4 @@ void populateLoweringKrnlSeqStoreOpPattern(mlir::TypeConverter &typeConverter, } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp index 45d5b211a9..2bc0fd3aae 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp @@ -4,7 +4,7 @@ //====------ ConvertKrnlToAffine.hpp - Krnl Dialect Lowering --------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_CONVERT_KRNL_TO_AFFINE_H +#define ONNX_MLIR_CONVERT_KRNL_TO_AFFINE_H #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" @@ -89,3 +90,4 @@ void populateLoweringKrnlTerminatorOpPattern(mlir::TypeConverter &typeConverter, } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp index c33bc72fc0..289d394a10 100644 --- a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp @@ -4,7 +4,7 @@ //===-------------- KrnlMatmul.cpp - Lower KrnlMatmulOp -------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -131,10 +131,10 @@ class KrnlMatmulLowering : public ConversionPattern { if (iComputeTileSize.isLiteral() && kComputeTileSize.isLiteral()) { uint64_t i = iComputeTileSize.getLiteral(); uint64_t k = kComputeTileSize.getLiteral(); - uint64_t mVL = create.vec.getMachineVectorLength(elementType); - if (i % mVL == 0 && k % mVL == 0) { - // Right now, vector length must be mVL. - vectorLen = LiteralIndexExpr(mVL); + uint64_t archVL = create.vec.getArchVectorLength(elementType); + if (i % archVL == 0 && k % archVL == 0) { + // Right now, vector length must be archVL. + vectorLen = LiteralIndexExpr(archVL); } else { simdize = false; LLVM_DEBUG(llvm::dbgs() << "Matmul: mat*vec with bad sizes: i " << i @@ -351,14 +351,14 @@ class KrnlMatmulLowering : public ConversionPattern { MemRefBuilder, KrnlBuilder> create(createAffine); int64_t iLit(I.getLiteral()), VL(vectorLen.getLiteral()); - int64_t mVL = create.vec.getMachineVectorLength(elementType); + int64_t archVL = create.vec.getArchVectorLength(elementType); // Get operands. KrnlMatMulOpAdaptor operandAdaptor = KrnlMatMulOpAdaptor(op); Value A(operandAdaptor.getA()), B(operandAdaptor.getB()), C(operandAdaptor.getC()); // Generate the vector type conversions. 
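The matmul simdization test above requires both the i and k compute-tile sizes to be multiples of the architecture vector length for the element type; only then is vectorLen pinned to archVL. A small sketch of that check (illustrative, assuming a fixed archVL):

```cpp
#include <cstdint>
#include <cstdio>

// Returns true (and sets vectorLen = archVL) when the i and k tile sizes are
// both divisible by the architecture vector length; otherwise simdization is
// turned off and scalar code is generated.
static bool pickMatVecVectorLen(
    uint64_t i, uint64_t k, uint64_t archVL, uint64_t &vectorLen) {
  if (i % archVL == 0 && k % archVL == 0) {
    vectorLen = archVL; // right now the vector length must be exactly archVL
    return true;
  }
  return false;
}

int main() {
  uint64_t vl = 0;
  printf("i=16,k=64,archVL=4 -> simd=%d vl=%llu\n",
      pickMatVecVectorLen(16, 64, 4, vl), (unsigned long long)vl);
  printf("i=6,k=64,archVL=4 -> simd=%d\n", pickMatVecVectorLen(6, 64, 4, vl));
  return 0;
}
```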
- assert(VL == mVL && "vector length and VL must be identical for now"); + assert(VL == archVL && "vector length and VL must be identical for now"); VectorType vecType = VectorType::get({VL}, elementType); int64_t iUnrollFactor = iLit; assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL"); @@ -405,7 +405,7 @@ class KrnlMatmulLowering : public ConversionPattern { } }); - // Reduce each SIMD vector of length mVL using a SIMD parallel reduction. + // Reduce each SIMD vector of length VL using a SIMD parallel reduction. SmallVector vProdList, vReductionList; for (int64_t i = 0; i < iUnrollFactor; ++i) { Value iVal = create.math.constantIndex(i); diff --git a/src/Conversion/KrnlToAffine/KrnlToAffineHelper.hpp b/src/Conversion/KrnlToAffine/KrnlToAffineHelper.hpp index 3e062cbe79..1aff72bd9c 100644 --- a/src/Conversion/KrnlToAffine/KrnlToAffineHelper.hpp +++ b/src/Conversion/KrnlToAffine/KrnlToAffineHelper.hpp @@ -4,7 +4,7 @@ //===------ KrnlToAffineHelper.hpp ----------------------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_TO_AFFINE_H +#define ONNX_MLIR_KRNL_TO_AFFINE_H #include "src/Dialect/Mlir/IndexExpr.hpp" @@ -25,3 +26,4 @@ IndexExpr trip(IndexExpr UB, IndexExpr block, IndexExpr GI); } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlToLLVM/ConstantOpInterface.cpp b/src/Conversion/KrnlToLLVM/ConstantOpInterface.cpp index ffcd65add8..2c31ddee8b 100644 --- a/src/Conversion/KrnlToLLVM/ConstantOpInterface.cpp +++ b/src/Conversion/KrnlToLLVM/ConstantOpInterface.cpp @@ -116,13 +116,14 @@ class ConstantOpInterfaceLowering return mlir::cast(a.getValue()[i]).getInt(); } - LLVM::GlobalOp lowerDenseResourceConstant(ConstantOpInterface &constOpInterface, - Type globalType, ConversionPatternRewriter &rewriter) const { + LLVM::GlobalOp lowerDenseResourceConstant( + ConstantOpInterface &constOpInterface, Type globalType, + ConversionPatternRewriter &rewriter) const { assert(constOpInterface.getValue().has_value() && "Expecting ConstantOpInterface with a valid value"); - assert( - mlir::isa(constOpInterface.getValue().value()) && - "Expecting a global with an dense resource elements attribute"); + assert(mlir::isa( + constOpInterface.getValue().value()) && + "Expecting a global with an dense resource elements attribute"); MLIRContext *context = constOpInterface.getContext(); Location loc = constOpInterface.getLoc(); @@ -132,10 +133,10 @@ class ConstantOpInterfaceLowering OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - auto blob = - mlir::cast(constOpInterface.getValue().value()) - .getRawHandle() - .getBlob(); + auto blob = mlir::cast( + constOpInterface.getValue().value()) + .getRawHandle() + .getBlob(); assert(blob && "Expecting dense resource with a valid blob"); ArrayRef rawData = blob->getData(); @@ -148,8 +149,8 @@ class ConstantOpInterfaceLowering auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes); LLVM::GlobalOp global = create.llvm.globalOp(llvmArrayI8Ty, - /*isConstant=*/true, LLVM::Linkage::Internal, constOpInterface.getName(), - llvmStringAttr); + /*isConstant=*/true, LLVM::Linkage::Internal, + constOpInterface.getName(), llvmStringAttr); 
LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";); return global; @@ -189,8 +190,8 @@ class ConstantOpInterfaceLowering StringRef data(rawData.data(), rawData.size()); StringAttr llvmStringAttr = StringAttr::get(context, data); global = create.llvm.globalOp(llvmArrayI8Ty, - /*isConstant=*/true, LLVM::Linkage::Internal, constOpInterface.getName(), - llvmStringAttr); + /*isConstant=*/true, LLVM::Linkage::Internal, + constOpInterface.getName(), llvmStringAttr); } else { if (mlir::isa(denseAttr.getElementType())) global = lowerStringLiteral(constOpInterface, globalType, rewriter); @@ -209,7 +210,8 @@ class ConstantOpInterfaceLowering ConversionPatternRewriter &rewriter) const { Location loc = constOpInterface.getLoc(); MLIRContext *context = constOpInterface.getContext(); - ModuleOp module = constOpInterface.getOperation()->getParentOfType(); + ModuleOp module = + constOpInterface.getOperation()->getParentOfType(); MultiDialectBuilder create(rewriter, loc); Type llvmI8Ty = IntegerType::get(context, 8); @@ -331,8 +333,8 @@ class ConstantOpInterfaceLowering // block. auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, offsets.size()); auto global = create.llvm.globalOp(arrayType, - /*isConstant=*/true, LLVM::Linkage::Internal, constOpInterface.getName(), - Attribute()); + /*isConstant=*/true, LLVM::Linkage::Internal, + constOpInterface.getName(), Attribute()); Region ®ion = global.getInitializerRegion(); Block *block = builder.createBlock(®ion); diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp index 1efa842dc7..5f945084ca 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -4,7 +4,7 @@ //====------ ConvertKrnlToLLVM.hpp - Krnl Dialect Lowering ---------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_CONVERT_KRNL_TO_LLVM_H +#define ONNX_MLIR_CONVERT_KRNL_TO_LLVM_H #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Pass/Passes.hpp" @@ -121,3 +122,4 @@ void genSignatureFunction(mlir::ModuleOp &module, const llvm::SmallVectorImpl &outSigGlobalOps); } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp index 9639f2ba50..bfccfda8be 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp @@ -55,9 +55,10 @@ class KrnlPrintOpLowering : public ConversionPattern { Value formatSpecPtr = getPtrToGlobalString(formatSpec, loc, rewriter); if (input) - create.llvm.call({}, printfFuncRef, {formatSpecPtr, input}); + create.llvm.call( + {}, printfFuncRef, {formatSpecPtr, input}, /*isVarArg*/ true); else - create.llvm.call({}, printfFuncRef, {formatSpecPtr}); + create.llvm.call({}, printfFuncRef, {formatSpecPtr}, /*isVarArg*/ true); rewriter.eraseOp(op); return success(); diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp index 031e37c126..ef616d12f3 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp @@ -4,7 +4,7 @@ //===------ KrnlToLLVMHelper.hpp ------------------------------------------===// // -// Copyright 2022 The IBM Research Authors. 
+// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_TO_LLVM_H +#define ONNX_MLIR_KRNL_TO_LLVM_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -93,3 +94,4 @@ bool isZOS(mlir::ModuleOp module); } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Conversion/KrnlToLLVM/RuntimeAPI.hpp b/src/Conversion/KrnlToLLVM/RuntimeAPI.hpp index 1a3caccb6b..1af24ecec1 100644 --- a/src/Conversion/KrnlToLLVM/RuntimeAPI.hpp +++ b/src/Conversion/KrnlToLLVM/RuntimeAPI.hpp @@ -4,7 +4,7 @@ //===------ RuntimeAPI.hpp - Declaration of the Runtime API ---------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RUNTIME_API_H +#define ONNX_MLIR_RUNTIME_API_H #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -100,3 +101,4 @@ class RuntimeAPIRegistry final { private: ApiRegistry registry; }; +#endif diff --git a/src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp b/src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp index c33dde1555..adde647472 100644 --- a/src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RNN_BASE_CONV_H +#define ONNX_MLIR_RNN_BASE_CONV_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" @@ -53,3 +54,4 @@ template std::tuple getActivationPack(RNNOp *op); } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 5309a01092..16bab29f5c 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -46,6 +46,7 @@ class ONNXEntryPointLowering : public OpRewritePattern { ONNXEntryPointOp::getEntryPointFuncAttrName()); StringRef entryPointName = funcRefAttr.getLeafReference().getValue(); Operation *entryPointOp = module.lookupSymbol(entryPointName); + assert(entryPointOp && "entry point name not found!"); func::FuncOp entryPointFunc = cast(entryPointOp); IntegerAttr numInputsAttr = @@ -222,8 +223,8 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, // ObjectDetection populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx); // Quantization - populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx); - populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx); + populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel); + populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel); // Tensor populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx); populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx); @@ -490,12 +491,9 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel, if (reportOnSimd) { if (!simdIsEnabled) { 
OnnxToKrnlLoweringConfiguration::defaultSimdComment = "simd is disabled"; - } else { - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); - if (!vms->hasSimd()) - OnnxToKrnlLoweringConfiguration::defaultSimdComment = - "cpu with unspecified simd ISA"; + } else if (!VectorMachineSupport::hasSimd()) { + OnnxToKrnlLoweringConfiguration::defaultSimdComment = + "cpu with unspecified simd ISA"; } } if (parallelIsEnabled) diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index c136e1f2c3..a1c4aaa35e 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -42,7 +42,8 @@ Value emitPostProcessingFor(ConversionPatternRewriter &rewriter, Location loc, template static void CheckIfCustomScalarOpIsSupported(Type elementType) { - Type actualElementType = MathBuilder::elementTypeWithVector(elementType); + Type actualElementType = + MathBuilder::elementTypeOfScalarOrVector(elementType); if (mlir::isa(actualElementType)) { if constexpr (std::is_same, CustomScalarOp>::value) return; @@ -55,34 +56,6 @@ static void CheckIfCustomScalarOpIsSupported(Type elementType) { } } -// ============================================================================= -// Template for SIMD analysis - -// Helper for function that support SIMD. -static double simdAnalysis(ArrayRef GOps, ArrayRef GOpsNum, - Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) { - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); - return vms->getAvgVectorLength( - GOps, GOpsNum, elementType, vectorizedOpNum, scalarOpNum); -} - -// Default template for ops that do not support SIMD. For the ones that support -// SIMD, we must create an `analyzeSimdFor` template that returns the right -// values. 
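The refactoring above replaces the per-op analyzeSimdFor helpers (parallel arrays of generic ops and counts) with a single getGenOpMix query that reports each op's mix of generic operations. A rough sketch of how such a mix could feed an average-SIMD-width estimate; the GenOpMix type, the per-target widths, and the averaging below are illustrative, not the actual onnx-mlir cost model:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

enum class GenericOps { ArithmeticGop, MulGop, ExpGop, DivGop };

// Illustrative stand-in for a "mix": generic op kind -> number of occurrences.
using GenOpMix = std::vector<std::pair<GenericOps, int64_t>>;

// Analogous to getGenOpMix<ONNXSigmoidOp>: 2 adds, 2 exps, 1 div.
static GenOpMix sigmoidMix() {
  return {{GenericOps::ArithmeticGop, 2}, {GenericOps::ExpGop, 2},
      {GenericOps::DivGop, 1}};
}

// Hypothetical per-target table: SIMD width for each generic op (0 = scalar only).
static int64_t simdWidthFor(GenericOps g) {
  switch (g) {
  case GenericOps::ArithmeticGop:
  case GenericOps::MulGop:
  case GenericOps::DivGop:
    return 4;
  case GenericOps::ExpGop:
    return 0;
  }
  return 0;
}

int main() {
  GenOpMix mix = sigmoidMix();
  int64_t vectorized = 0, scalar = 0;
  double weightedWidth = 0.0, totalOps = 0.0;
  for (auto &kv : mix) {
    int64_t w = simdWidthFor(kv.first);
    if (w > 0) {
      vectorized += kv.second;
      weightedWidth += w * kv.second;
    } else {
      scalar += kv.second;
      weightedWidth += 1 * kv.second; // scalar ops count with width 1
    }
    totalOps += kv.second;
  }
  printf("vectorized=%lld scalar=%lld avg width=%.2f\n", (long long)vectorized,
      (long long)scalar, weightedWidth / totalOps);
  return 0;
}
```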
- -static double noSimd(int64_t &vectorizedOpNum, int64_t &scalarOpNum) { - vectorizedOpNum = 0; - scalarOpNum = 1; - return 1.0; -} - -template -double analyzeSimdFor(Type elementType, Operation *op, int64_t &vectorizedOpNum, - int64_t &scalarOpNum) { - return noSimd(vectorizedOpNum, scalarOpNum); -} - // ============================================================================= // Scalar ops handling @@ -92,9 +65,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::TrigHyperbolicGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::TrigHyperbolicGop, 1}}; } template <> @@ -103,9 +75,8 @@ struct ScalarOp { using IOp = arith::AddIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -114,9 +85,8 @@ struct ScalarOp { using IOp = math::AbsIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::AbsGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::AbsGop, 1}}; } template <> @@ -125,9 +95,8 @@ struct ScalarOp { using IOp = arith::MulIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::MulGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::MulGop, 1}}; } template <> @@ -136,9 +105,8 @@ struct ScalarOp { using IOp = arith::DivSIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::DivGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::DivGop, 1}}; } template <> @@ -147,9 +115,8 @@ struct ScalarOp { using IOp = arith::SubIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -194,9 +161,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ExpGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ExpGop, 1}}; } template <> @@ -205,9 +171,8 @@ struct ScalarOp { using IOp = arith::AddIOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -216,9 +181,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::TrigGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::TrigGop, 1}}; } template <> @@ -227,9 +191,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return 
simdAnalysis({GenericOps::LogGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::LogGop, 1}}; } template <> @@ -238,9 +201,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::SqrtGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::SqrtGop, 1}}; } template <> @@ -255,9 +217,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::CeilGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::CeilGop, 1}}; } template <> @@ -266,9 +227,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::FloorGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::FloorGop, 1}}; } template <> @@ -277,9 +237,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::TrigGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::TrigGop, 1}}; } template <> @@ -288,9 +247,8 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::PowGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::PowGop, 1}}; } template <> @@ -345,17 +303,14 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { +GenOpMix getGenOpMix(Type t, Operation *op) { StringRef approximate = dyn_cast(op).getApproximate(); if (approximate.equals_insensitive("none")) - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::ErfGop, GenericOps::MulGop}, - {1, 1, 3}, t, von, son); + return {{GenericOps::ArithmeticGop, 1}, {GenericOps::ErfGop, 1}, + {GenericOps::MulGop, 3}}; if (approximate.equals_insensitive("tanh")) - return simdAnalysis({GenericOps::ArithmeticGop, GenericOps::MulGop, - GenericOps::TrigHyperbolicGop}, - {2, 5, 1}, t, von, son); + return {{GenericOps::ArithmeticGop, 2}, {GenericOps::MulGop, 5}, + {GenericOps::TrigHyperbolicGop, 1}}; llvm_unreachable("approximate should be only none or tanh"); } @@ -484,11 +439,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::ExpGop, GenericOps::DivGop}, - {2, 2, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 2}, {GenericOps::ExpGop, 2}, + {GenericOps::DivGop, 1}}; } template <> @@ -518,11 +471,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::ExpGop, GenericOps::DivGop}, - {2, 2, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 2}, {GenericOps::ExpGop, 2}, + {GenericOps::DivGop, 1}}; } template <> @@ -552,11 +503,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, 
int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::ExpGop, GenericOps::DivGop}, - {2, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 2}, {GenericOps::ExpGop, 1}, + {GenericOps::DivGop, 1}}; } template <> @@ -585,10 +534,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::MulGop}, {2, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 3}, {GenericOps::MulGop, 1}}; } template <> @@ -629,12 +576,10 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::MulGop, GenericOps::CompareGop, - GenericOps::SelectGop, GenericOps::ExpGop}, - {1, 1, 1, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 1}, + {GenericOps::CompareGop, 1}, {GenericOps::SelectGop, 1}, + {GenericOps::ExpGop, 1}}; } template <> @@ -667,9 +612,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -693,11 +637,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::CompareGop, GenericOps::SelectGop, GenericOps::MulGop}, - {1, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::CompareGop, 1}, {GenericOps::SelectGop, 1}, + {GenericOps::MulGop, 1}}; } template <> @@ -728,11 +670,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::CompareGop, GenericOps::SelectGop, GenericOps::MulGop}, - {1, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::CompareGop, 1}, {GenericOps::SelectGop, 1}, + {GenericOps::MulGop, 1}}; } template <> @@ -760,12 +700,10 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::CompareGop, GenericOps::SelectGop, GenericOps::MulGop, - GenericOps::ArithmeticGop, GenericOps::ExpGop}, - {1, 1, 2, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::CompareGop, 1}, {GenericOps::SelectGop, 1}, + {GenericOps::MulGop, 2}, {GenericOps::ArithmeticGop, 1}, + {GenericOps::ExpGop, 1}}; } template <> @@ -802,9 +740,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::DivGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::DivGop, 1}}; } template <> @@ -829,11 +766,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ExpGop, GenericOps::ArithmeticGop, GenericOps::LogGop}, - {1, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ExpGop, 1}, {GenericOps::ArithmeticGop, 1}, + {GenericOps::LogGop, 1}}; } template <> @@ -860,11 
+795,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::AbsGop, GenericOps::ArithmeticGop, GenericOps::DivGop}, - {1, 1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::AbsGop, 1}, {GenericOps::ArithmeticGop, 1}, + {GenericOps::DivGop, 1}}; } template <> @@ -891,10 +824,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::CompareGop, GenericOps::SelectGop}, {2, 2}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::CompareGop, 2}, {GenericOps::SelectGop, 2}}; } template <> @@ -914,7 +845,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // ConstantOp 0, // %Y) Value plusSelect; - if (create.math.isUnsignedIntegerWithVector(elementType)) { + if (create.math.isScalarOrVectorUnsignedInteger(elementType)) { // Unsigned integers are by definition positive. plusSelect = one; } else { @@ -935,9 +866,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ErfGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ErfGop, 1}}; } //===----------------------------------------------------------------------===// @@ -950,9 +880,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -979,9 +908,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -1009,9 +937,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -1172,10 +1099,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::RemGop, GenericOps::CopySignGop}, {1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::RemGop, 1}, {GenericOps::CopySignGop, 1}}; } template <> @@ -1188,7 +1113,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, MultiDialectBuilder create(rewriter, loc); // TODO: here we assume fmod=1, what should if that is not the case? - if (create.math.isFloatWithVector(elementType)) { + if (create.math.isScalarOrVectorFloat(elementType)) { // fmod is always 1. Behavior is like numpy.fmod. // The sign of the remainder is the same as the dividend. Value rem = create.math.rem(dividend, divisor); @@ -1201,7 +1126,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, return create.math.copySign(rem, dividend); #endif } - if (create.math.isIntegerWithVector(elementType)) { + if (create.math.isScalarOrVectorInteger(elementType)) { // "math.rem" returns "minus" for minus dividend and "plus or zero" for plus // dividend. 
We call the math.rem's return value "mathRemainder". However // onnx.ModOp should return "minus" for minus divisor and "plus or zero" for @@ -1242,15 +1167,6 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Value answer = create.math.select(needAdjust, adjustedRemainder, mathRemainder); -#ifdef DEBUG_ONNX_MOD - create.krnl.printf("XXXX emitScalarOpFor: dividend=", dividend); - create.krnl.printf(", divisor=", divisor); - create.krnl.printf(", mathReminder=", mathRemainder); - create.krnl.printf(", adjustedReminder=", adjustedRemainder); - create.krnl.printf(", Answer=", answer); - create.krnl.printf("\n"); -#endif - return answer; } llvm_unreachable("unsupported element type"); @@ -1266,10 +1182,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::DivGop}, {1, 1}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}, {GenericOps::DivGop, 1}}; } template <> @@ -1291,12 +1205,10 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis( - {GenericOps::ArithmeticGop, GenericOps::MulGop, GenericOps::CompareGop, - GenericOps::SelectGop, GenericOps::FloorGop}, - {4, 2, 3, 3, 2}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2}, + {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, + {GenericOps::FloorGop, 2}}; } template <> @@ -1306,43 +1218,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Value x = scalarOperands[0]; MultiDialectBuilder create(rewriter, loc); CheckIfCustomScalarOpIsSupported(elementType); - // Use numpy algorithm for rint as follows. - // ``` - // double y, r; - // y = npy_floor(x); - // r = x - y; - // - // if (r > 0.5) { - // y += 1.0; - // } - // - // /* Round to nearest even */ - // if (r == 0.5) { - // r = y - 2.0*npy_floor(0.5*y); - // if (r == 1.0) { - // y += 1.0; - // } - // } - // return y; - // ``` - Value one = create.math.constant(elementType, 1.0); - Value two = create.math.constant(elementType, 2.0); - Value half = create.math.constant(elementType, 0.5); - Value y = create.math.floor(x); - Value r = create.math.sub(x, y); - // r > 0.5 - Value rGreaterThanHalf = create.math.sgt(r, half); - Value y1 = create.math.select(rGreaterThanHalf, create.math.add(y, one), y); - // r == 0.5: round to nearest even. - Value y2 = create.math.mul(half, y); - y2 = create.math.floor(y2); - y2 = create.math.mul(y2, two); - Value rr = create.math.sub(y, y2); - Value rrEqualOne = create.math.eq(rr, one); - y2 = create.math.select(rrEqualOne, create.math.add(y, one), y); - - Value rEqualHalf = create.math.eq(r, half); - return create.math.select(rEqualHalf, y2, y1); + return create.math.round(x); } //===----------------------------------------------------------------------===// @@ -1355,9 +1231,8 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - return simdAnalysis({GenericOps::ArithmeticGop}, {2}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 2}}; } template <> @@ -1385,14 +1260,9 @@ struct ScalarOp { }; template <> -double analyzeSimdFor( - Type t, Operation *op, int64_t &von, int64_t &son) { - // Right now, MLIR vector:splat does not support unsigned int types. 
- // Thus we must disable SIMD here for now. - return noSimd(von, son); - // return simdAnalysis({GenericOps::ArithmeticGop, GenericOps::MulGop, - // GenericOps::ConversionGop}, - // {1, 1, 2}, t, von, son); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 1}, + {GenericOps::ConversionGop, 2}}; } template <> @@ -1421,52 +1291,6 @@ Value emitScalarOpFor( using MDBuilder = MultiDialectBuilder; -// Return unrolled vector length; no simd -> return 0; -// collapsedLiteralSize is ignored when we can collapse every loop iterations as -// we then rely on padding of the allocated memory to enable arbitrary output -// array simdization. When partial simd is requested, then we must ensure that -// the collapsed loop cumulative static size is a multiple of the VL. -template -int64_t canBeVectorized(ShapeHelperType &shapeHelper, MDBuilder &create, - Operation *op, MemRefType memRefType, int64_t collapsedInnermostLoops, - int64_t &estimatedSimdLoopTripCount) { - estimatedSimdLoopTripCount = 0; // Initially assume no SIMD. - int64_t simdUnroll; - int64_t uVL = 0; - // SIMD is enabled for this operation, test if profitable. - Type elementType = memRefType.getElementType(); - int64_t vectorizedOpNum, scalarOpNum; - double avgSimdWidth = analyzeSimdFor( - elementType, op, vectorizedOpNum, scalarOpNum); - if (avgSimdWidth < 1.5) { - LLVM_DEBUG(llvm::dbgs() << " simd disabled: avg simd width " - << avgSimdWidth << " too small\n"); - return 0; - } - // Determine empirical unroll factor. - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); - - int64_t vrNum = vms->VectorRegisterNum(); - if (vectorizedOpNum >= vrNum / 2) - simdUnroll = 1; // TODO, it would appear to be beneficial to always have 2. - else if (vectorizedOpNum >= vrNum / 4) - simdUnroll = 4; - else - simdUnroll = 8; - uVL = create.vec.computeSuitableUnrollFactor(vms, memRefType, - shapeHelper.getOutputDims(), collapsedInnermostLoops, simdUnroll, - /*canPad*/ true, estimatedSimdLoopTripCount); - LLVM_DEBUG({ - if (uVL) - llvm::dbgs() << " simd enabled with vector length " << uVL << "\n"; - else - LLVM_DEBUG( - llvm::dbgs() << " simd disabled, no feasible with unroll factor\n"); - }); - return uVL; -} - //===----------------------------------------------------------------------===// // SIMD code gen for kernels where data can be partially or fully flattened. //===----------------------------------------------------------------------===// @@ -1477,23 +1301,29 @@ static LogicalResult getPartiallyFlattenedSimdCode( ONNXBroadcastOpShapeHelper *shapeHelper, Operation *op, MemRefType outputMemRefType, ValueRange operands, int64_t alignment, int64_t VL, int64_t collapsedInnermostLoops, bool ruledOutBroadcast, - bool isUnaryOp, bool enableParallel) { + bool isUnaryOp, bool simdOnly, bool enableParallel) { Type outputElementType = outputMemRefType.getElementType(); unsigned numArgs = op->getNumOperands(); LLVM_DEBUG(llvm::dbgs() << " partial SIMD code for elementwise op " << op->getName() << " flattening " << collapsedInnermostLoops << " inner dims\n"); - - // generate SIMD code of VL elements per vector. + // If fully collapse the loop, then we can allocate more data and we don't + // care if we compute a few more values... set simdOnly to true then + // regardless of whether the dims allow us to do so or not. 
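In the rewritten elementwise lowering, the innermost collapsedInnermostLoops dimensions are flattened into one SIMD dimension, the collapse count is clamped to each input's rank (e.g. fusing a 1x1x128 output with a 128 input), and simdOnly is forced to true when the loop nest is fully collapsed. A small sketch of the flattening and clamping over static shapes (plain C++, illustrative only):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Collapse the innermost `collapse` dimensions of `dims` into one flattened
// dimension, clamping `collapse` to the rank of the value being reshaped.
static std::vector<int64_t> flattenInnermost(
    const std::vector<int64_t> &dims, int64_t collapse) {
  int64_t rank = (int64_t)dims.size();
  collapse = std::min(collapse, rank); // inputs may have a smaller rank
  std::vector<int64_t> flat(dims.begin(), dims.end() - collapse);
  int64_t prod = 1;
  for (int64_t i = rank - collapse; i < rank; ++i)
    prod *= dims[i];
  flat.push_back(prod);
  return flat;
}

int main() {
  auto out = flattenInnermost({1, 1, 128}, 3); // fully collapsed -> {128}
  auto in = flattenInnermost({128}, 3);        // clamped to rank 1 -> {128}
  printf("out rank %zu, innermost %lld; in rank %zu, innermost %lld\n",
      out.size(), (long long)out.back(), in.size(), (long long)in.back());
  return 0;
}
```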
+ if (collapsedInnermostLoops == (int64_t)outputMemRefType.getRank()) { + LLVM_DEBUG(llvm::dbgs() << " fully flattened, set simdOnly to true\n"); + simdOnly = true; + } + // Generate SIMD code of VL elements per vector. IndexExprScope allocScope(create.vec, shapeHelper->getScope()); DimsExpr outputDims; getIndexExprList(shapeHelper->getOutputDims(), outputDims); // Alloc memory with padding for SIMD. // For the moment, its ok to go here; if we truly have partial flattening of // the simd code, then we only do it with static memref size that are - // multiples of VL * simdUnroll, so there should be no padding anyway. This + // multiples of VL * unrollVL, so there should be no padding anyway. This // will change if we do partial flattening with non-multiple of VL * - // simdUnroll. + // unrollVL. Value alloc = create.mem.alignedAllocWithSimdPadding( outputMemRefType, outputDims, VL, alignment); // Create flat inputs in the last innerDinNum dims. @@ -1506,8 +1336,13 @@ static LogicalResult getPartiallyFlattenedSimdCode( } DimsExpr operDims, flattenOperDims; create.krnlIE.getShapeAsSymbols(oper, operDims); + // Because we fully fuse 1x1x128xf32 and 128xf32, the + // collapsedInnermostLoops may be higher than the rank of this input. Adjust + // collapsedInnermostLoops accordingly for the flatten below. + int64_t currRank = operDims.size(); + int64_t currCollapsedNum = std::min(collapsedInnermostLoops, currRank); Value flatOper = create.mem.reshapeToFlatInnermost( - oper, operDims, flattenOperDims, collapsedInnermostLoops); + oper, operDims, flattenOperDims, currCollapsedNum); flatOperands.emplace_back(flatOper); } @@ -1515,111 +1350,130 @@ static LogicalResult getPartiallyFlattenedSimdCode( int64_t rank = outputDims.size() - collapsedInnermostLoops + 1; LLVM_DEBUG( llvm::dbgs() << "SIMD partial flatten with loop rank " << rank << "\n"); - int64_t flattenedDim = rank - 1; SmallVector flattenedOutputDims; Value flatAlloc = create.mem.reshapeToFlatInnermost( alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops); - // Create loop iteration (flattened to output dim - inner dim + 1) with inner - // one and blocked by mVL. - ValueRange loopDef = create.krnl.defineLoops(rank); - ValueRange blockedLoopDef = create.krnl.block(loopDef[flattenedDim], VL); - SmallVector optimizedLoopDef; - for (int64_t r = 0; r < rank - 1; ++r) { - optimizedLoopDef.emplace_back(loopDef[r]); - } - optimizedLoopDef.emplace_back(blockedLoopDef[0]); - // Create the vector type to operate over. - VectorType vecElementType = VectorType::get({VL}, outputElementType); + + // Create loop iteration, rank-1, all but the flattened innermost [simd] loop. + int64_t outerLoopRank = rank - 1; + ValueRange loopDef = create.krnl.defineLoops(outerLoopRank); // Iterate only over the blocks. - SmallVector lbs(rank, LiteralIndexExpr(0)); + IndexExpr zero = LitIE(0); + DimsExpr lbs(outerLoopRank, zero); + DimsExpr ubs = flattenedOutputDims; + IndexExpr simdUb = ubs.pop_back_val(); // Remove flattened ub. + bool useParallelInSimdLoop = false; if (enableParallel) { int64_t parId; - if (findSuitableParallelDimension( - lbs, flattenedOutputDims, 0, std::min((int64_t)2, rank), parId)) { - create.krnl.parallel(optimizedLoopDef[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], - flattenedOutputDims[parId], "elementwise simd partially flattened"); + if (outerLoopRank > 1) { + // Outer loop parallelism. 
+ if (findSuitableParallelDimension( + lbs, ubs, 0, std::min((int64_t)2, outerLoopRank), parId)) { + create.krnl.parallel(loopDef[parId]); + onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], + "outer-loop of elementwise simd partially flattened"); + } else { + onnxToKrnlParallelReport(op, false, -1, -1, + "not enough work in outermost-loops of elementwise simd partially " + "flattened"); + } } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in elementwise simd partially flattened"); + // SIMD loop parallelism. + DimsExpr simdLbs = {zero}, simdUbs = {simdUb}; + if (findSuitableParallelDimension( + simdLbs, simdUbs, 0, 1, parId, VL * 32)) { + assert(parId == 0 && "expected loop zero to be parallelized"); + useParallelInSimdLoop = true; + onnxToKrnlParallelReport(op, true, parId, zero, simdUb, + "innermost-loop of elementwise simd partially flattened"); + } else { + onnxToKrnlParallelReport(op, false, -1, -1, + "not enough work in innermost-loop of elementwise simd partially " + "flattened"); + } } } - create.krnl.iterateIE(loopDef, optimizedLoopDef, lbs, flattenedOutputDims, - [&](KrnlBuilder &ck, ValueRange loopInd) { - MultiDialectBuilder create(ck); - SmallVector outputAccessExprs; - getIndexExprList(loopInd, outputAccessExprs); - - llvm::SmallVector loadedVals; - // Load all the values - for (int64_t i = 0; i < (int64_t)flatOperands.size(); ++i) { - Value flatOper = flatOperands[i]; - if (isNoneValue(flatOper)) { - // None, just pass it on unmodified. - loadedVals.emplace_back(flatOper); + create.krnl.iterateIE( + loopDef, loopDef, lbs, ubs, [&](KrnlBuilder &ck, ValueRange loopInd) { + MultiDialectBuilder create(ck); + // LoopInd has the current indices for all but the innermost dim. Since + // we expect here the entire innermost loop iteration in one go, the + // innermost loop starts at zero. Add here to the list of Dim symbols. + SmallVector outputAccessExprs = DimListIE(loopInd); + outputAccessExprs.emplace_back(zero); + + // Have to produce the list of input values and their access functions. + llvm::SmallVector inputs = flatOperands; + llvm::SmallVector inputAFs; + for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) { + Value input = inputs[i]; + // Define the access function for each of the inputs. + DimsExpr inputAF; + // Check if we have a none value. + if (isNoneValue(input)) { + // Have one, pass the none value with empty AF. + inputAFs.emplace_back(inputAF); continue; } - MemRefType memRefType = - mlir::dyn_cast(flatOper.getType()); - assert(memRefType && "expected memref"); - VectorType vecType = - VectorType::get({VL}, memRefType.getElementType()); - if (hasOneElementInInnermostDims(flatOper, 1)) { - // If its a scalar, do a scalar load and splat. - llvm::SmallVector scalarAccessFct; - if (hasOneElement(flatOper)) { - // Not flattened, with only 1 dims, just put zeros as needed. - int64_t scalarRank = - mlir::dyn_cast(flatOper.getType()).getRank(); - for (int r = 0; r < scalarRank; ++r) - scalarAccessFct.emplace_back(LiteralIndexExpr(0)); - - } else { - // Was flattened, with non 1 dims, use get access expr. 
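Each input then gets its own access function: a none value gets an empty one, a one-element tensor reads index zero everywhere, and everything else goes through the broadcast-aware shape helper. The sketch below models only the broadcast case, with ordinary vectors standing in for `IndexExpr` lists:

```cpp
// Standalone sketch (not the ONNX shape-helper API): derive an input access
// index from an output index under numpy-style broadcasting. Size-1 dims of
// the input always read element 0; missing leading dims are dropped.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> broadcastAccess(const std::vector<int64_t> &outIdx,
                                     const std::vector<int64_t> &inShape) {
  std::vector<int64_t> inIdx;
  size_t offset = outIdx.size() - inShape.size(); // input is right-aligned
  for (size_t d = 0; d < inShape.size(); ++d)
    inIdx.push_back(inShape[d] == 1 ? 0 : outIdx[offset + d]);
  return inIdx;
}

int main() {
  // Output index (1, 2, 5) into a [4][3][8] result; input has shape [3][1].
  auto idx = broadcastAccess({1, 2, 5}, {3, 1});
  std::cout << idx[0] << "," << idx[1] << "\n"; // 2,0
}
```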
- LogicalResult res = - shapeHelper->getAccessExprs(flatOper, i, outputAccessExprs, - scalarAccessFct, /*flattened*/ true, ruledOutBroadcast); - assert(succeeded(res) && "Could not compute access indices"); - } - Value loadedVal = create.krnl.loadIE(flatOper, scalarAccessFct); - Value splatValue = create.vec.splat(vecType, loadedVal); - loadedVals.emplace_back(splatValue); - } else { - llvm::SmallVector loadAccessFct; - LogicalResult res = - shapeHelper->getAccessExprs(flatOper, i, outputAccessExprs, - loadAccessFct, /*flattened*/ true, ruledOutBroadcast); - assert(succeeded(res) && "Could not compute access indices"); - Value loadedVal = - create.vec.loadIE(vecType, flatOper, loadAccessFct, {}); - loadedVals.emplace_back(loadedVal); - } - } - Value finalResult; - if (isUnaryOp) { - // For unary op, we through all operands at once as the other ones are - // scalars / none values. - finalResult = emitScalarOpFor( - rewriter, create.getLoc(), op, vecElementType, loadedVals); - } else { - // For non-unary ops, each op is a flattened array that need to be - // processed; process the two first ones, and then "accumulate" one - // value at a time. Use the first operand as temporary result. - Value accumulated = loadedVals[0]; - // Iterate over the remaining operands. - for (unsigned i = 1; i < numArgs; ++i) { - Value next = loadedVals[i]; - // Fold. - accumulated = emitScalarOpFor(rewriter, create.getLoc(), - op, vecElementType, {accumulated, next}); + // We have a memref, analyze which kind. + MemRefType inputType = mlir::dyn_cast(input.getType()); + assert(inputType && "expected memref"); + // Check if we have a scalar. + if (hasOneElement(input)) { + // Not flattened, with only 1 dims, just put zeros as needed. + int64_t inputRank = inputType.getRank(); + for (int r = 0; r < inputRank; ++r) + inputAF.emplace_back(zero); + inputAFs.emplace_back(inputAF); + continue; } - // Postprocessing (dummy op if none). - finalResult = emitPostProcessingFor( - rewriter, create.getLoc(), op, vecElementType, accumulated); + // We have a regular access. + LogicalResult res = + shapeHelper->getAccessExprs(input, i, outputAccessExprs, inputAF, + /*flattened*/ true, ruledOutBroadcast); + assert(succeeded(res) && "Could not compute access indices"); + inputAFs.emplace_back(inputAF); } - // Store result in the resulting array. - create.vec.store(finalResult, flatAlloc, loopInd); - }); + // Produce the list of outputs and output AFs + Value output = flatAlloc; + DimsExpr outputAF = outputAccessExprs; + + create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly, + useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF}, + [&](KrnlBuilder &kb, ArrayRef inputVals, + SmallVectorImpl &resVals, int64_t VL) { + MultiDialectBuilder create(kb); + Type currElementType = outputElementType; + if (VL > 1) + currElementType = VectorType::get({VL}, outputElementType); + Value res; + if (isUnaryOp) { + // For unary op, we through all operands at once as the other + // ones are scalars / none values. + res = emitScalarOpFor( + rewriter, create.getLoc(), op, currElementType, inputVals); + } else { + // For non-unary ops, each op is a flattened array that need to + // be processed; process the two first ones, and then + // "accumulate" one value at a time. Use the first operand as + // temporary result. + Value accumulated = inputVals[0]; + // Iterate over the remaining operands. + for (unsigned i = 1; i < numArgs; ++i) { + Value next = inputVals[i]; + // Fold. 
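Inside the SIMD body, variadic ops are folded pairwise: the first loaded operand seeds the accumulator, each remaining operand is combined with `emitScalarOpFor`, and a post-processing step (a no-op for most ops) finishes the value. A self-contained model of that fold, with plain doubles and lambdas standing in for MLIR values and the scalar emitters:

```cpp
// Standalone model of the accumulate-then-postprocess pattern used for
// variadic elementwise ops: fold operands two at a time, then apply an
// optional post-processing step.
#include <functional>
#include <iostream>
#include <vector>

double emitVariadic(const std::vector<double> &vals,
    const std::function<double(double, double)> &scalarOp,
    const std::function<double(double)> &postProcess) {
  double acc = vals[0];           // first operand seeds the accumulator
  for (size_t i = 1; i < vals.size(); ++i)
    acc = scalarOp(acc, vals[i]); // fold one operand at a time
  return postProcess(acc);        // dummy op if none (e.g. for Sum)
}

int main() {
  std::vector<double> loaded = {1.0, 2.0, 3.0, 4.0};
  auto add = [](double a, double b) { return a + b; };
  auto identity = [](double a) { return a; };
  std::cout << emitVariadic(loaded, add, identity) << "\n"; // 10
}
```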
+ accumulated = + emitScalarOpFor(rewriter, create.getLoc(), op, + currElementType, {accumulated, next}); + } + // Postprocessing (dummy op if none). + res = emitPostProcessingFor(rewriter, create.getLoc(), + op, currElementType, accumulated); + } + resVals.emplace_back(res); + }); // SIMD kernel. + }); // Outer loops. + rewriter.replaceOp(op, alloc); return success(); } @@ -1814,6 +1668,11 @@ bool OpFusionHelper::isControlFlowValidForFusion( // function by fold function. bool OpFusionHelper::areInputsValidForFusion( Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis) { + // Do not fuse ops with scalar tensors. + if (llvm::all_of( + useOp->getOperands(), [](Value v) { return isScalarTensor(v); })) + return false; + // Elementwise unary operation is always fusible if (useOp->getOperands().size() == 1) return true; @@ -2065,7 +1924,11 @@ struct ONNXElementwiseUnaryOpLowering "Failed to convert type to MemRefType"); MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getRank(); - Type elementType = outputMemRefType.getElementType(); + Type outputElementType = outputMemRefType.getElementType(); + + // In unary, we don't have any broadcast, and thus our target is to fully + // collapse the loop to a 1D loop. + int64_t collapsedInnermostLoops = outputRank; // Shape helper. MDBuilder create(rewriter, loc); @@ -2078,21 +1941,23 @@ struct ONNXElementwiseUnaryOpLowering bool isScalar = hasAllScalarValues(operands); // SIMD is enabled for this operation, test if desired and feasible if (enableSIMD && !isScalar && !hasNonIdentityLayout(operands)) { - int64_t estimatedSimdLoopTripCount; - int64_t uVL = canBeVectorized( - shapeHelper, create, op, outputMemRefType, outputRank, - estimatedSimdLoopTripCount); - if (uVL > 0) { - onnxToKrnlSimdReport(op, /*successful*/ true, uVL, - estimatedSimdLoopTripCount, "unary fully flattened"); + int64_t simdLoopStaticTripCount; + bool simdOnly, canOverCompute = true; + GenOpMix mix = getGenOpMix(outputElementType, op); + int64_t totVL = + computeSuitableUnrollFactor(outputMemRefType, collapsedInnermostLoops, + mix, canOverCompute, simdLoopStaticTripCount, simdOnly); + if (totVL > 1) { + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "unary fully flattened"); return getPartiallyFlattenedSimdCode(rewriter, create, &shapeHelper, op, outputMemRefType, operands, alignment, - uVL, /*collapsedInnermostLoop*/ outputRank, - /*ruleOutBroadcast*/ true, /*unary*/ true, enableParallel); + totVL, collapsedInnermostLoops, /*ruleOutBroadcast*/ true, + /*unary*/ true, simdOnly, enableParallel); } - onnxToKrnlSimdReport(op, /*successful*/ false, 0, - estimatedSimdLoopTripCount, + onnxToKrnlSimdReport(op, /*successful*/ false, 0, simdLoopStaticTripCount, "no simd in unary because could not find beneficial VL"); + } else { onnxToKrnlSimdReport(op, /*successful*/ false, 0, 0, "no simd in unary because scalar/layouts"); @@ -2143,7 +2008,7 @@ struct ONNXElementwiseUnaryOpLowering args.emplace_back(loadedVal); } auto loweredOpResult = emitScalarOpFor( - rewriter, loc, op, elementType, args); + rewriter, loc, op, outputElementType, args); loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult, alloc, loopInd); // Store result in the resulting array. 
@@ -2165,7 +2030,7 @@ struct ONNXElementwiseUnaryOpLowering args.emplace_back(loadedVal); } auto loweredOpResult = emitScalarOpFor( - rewriter, loc, op, elementType, args); + rewriter, loc, op, outputElementType, args); loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult, alloc); // Store result in the resulting array. create.krnl.store(loweredOpResult, alloc); @@ -2175,7 +2040,7 @@ struct ONNXElementwiseUnaryOpLowering opFusionHelper.replaceOrEraseONNXOps(alloc); return success(); } -}; // namespace onnx_mlir +}; //===----------------------------------------------------------------------===// // Element-wise binary ops lowering to Krnl dialect. @@ -2219,7 +2084,7 @@ struct ONNXElementwiseBinaryOpLowering "Failed to convert type to MemRefType"); MemRefType outputMemRefType = mlir::cast(convertedType); Type outputElementType = outputMemRefType.getElementType(); - uint64_t outputRank = outputMemRefType.getRank(); + int64_t outputRank = outputMemRefType.getRank(); // Shape helper. MDBuilder create(rewriter, loc); @@ -2254,25 +2119,25 @@ struct ONNXElementwiseBinaryOpLowering // SIMD is enabled for this operation, test if desired and feasible if (enableSIMD && !isScalar && hasManageableBroadcast && !hasNonIdentityLayout(operands)) { - int64_t estimatedSimdLoopTripCount; - int64_t uVL = - canBeVectorized( - shapeHelper, create, op, outputMemRefType, - collapsedInnermostLoops, estimatedSimdLoopTripCount); - if (uVL > 0) { - if (collapsedInnermostLoops == (int64_t)outputRank) - onnxToKrnlSimdReport(op, /*successful*/ true, uVL, - estimatedSimdLoopTripCount, "binary fully flattened"); + int64_t simdLoopStaticTripCount; + bool simdOnly, canOverCompute = collapsedInnermostLoops == outputRank; + GenOpMix mix = getGenOpMix(outputElementType, op); + int64_t totVL = + computeSuitableUnrollFactor(outputMemRefType, collapsedInnermostLoops, + mix, canOverCompute, simdLoopStaticTripCount, simdOnly); + if (totVL > 1) { + if (collapsedInnermostLoops == outputRank) + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "binary fully flattened"); else - onnxToKrnlSimdReport(op, /*successful*/ true, uVL, - estimatedSimdLoopTripCount, "binary with manageable broadcast"); + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "binary with manageable broadcast"); return getPartiallyFlattenedSimdCode(rewriter, create, &shapeHelper, op, outputMemRefType, operands, alignment, - uVL, collapsedInnermostLoops, hasNoBroadcast, - /*unary*/ false, enableParallel); + totVL, collapsedInnermostLoops, hasNoBroadcast, + /*unary*/ false, simdOnly, enableParallel); } - onnxToKrnlSimdReport(op, /*successful*/ false, 0, - estimatedSimdLoopTripCount, + onnxToKrnlSimdReport(op, /*successful*/ false, 0, simdLoopStaticTripCount, "no simd in binary because no beneficial VL"); } else { onnxToKrnlSimdReport(op, /*successful*/ false, 0, 0, @@ -2299,7 +2164,7 @@ struct ONNXElementwiseBinaryOpLowering if (enableParallel) { int64_t parId; if (findSuitableParallelDimension( - lbs, ubs, 0, std::min((uint64_t)2, outputRank), parId)) { + lbs, ubs, 0, std::min((int64_t)2, outputRank), parId)) { create.krnl.parallel(loopDef[parId]); onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], "elementwise binary not simdized"); @@ -2396,7 +2261,7 @@ struct ONNXElementwiseVariadicOpLowering "Failed to convert type to MemRefType"); MemRefType outputMemRefType = mlir::cast(convertedType); Type outputElementType = outputMemRefType.getElementType(); - uint64_t outputRank = 
outputMemRefType.getRank(); + int64_t outputRank = outputMemRefType.getRank(); // Shape helper. MDBuilder create(rewriter, loc); @@ -2429,25 +2294,25 @@ struct ONNXElementwiseVariadicOpLowering if (enableSIMD && !isScalar && hasManageableBroadcast && !hasNonIdentityLayout(operands)) { // SIMD is enabled for this operation, test if desired and feasible - int64_t estimatedSimdLoopTripCount; - int64_t uVL = - canBeVectorized( - shapeHelper, create, op, outputMemRefType, - collapsedInnermostLoops, estimatedSimdLoopTripCount); - if (uVL > 0) { - if (collapsedInnermostLoops == (int64_t)outputRank) - onnxToKrnlSimdReport(op, /*successful*/ true, uVL, - estimatedSimdLoopTripCount, "variadic fully flattened"); + int64_t simdLoopStaticTripCount; + bool simdOnly, canOverCompute = collapsedInnermostLoops == outputRank; + GenOpMix mix = getGenOpMix(outputElementType, op); + int64_t totVL = + computeSuitableUnrollFactor(outputMemRefType, collapsedInnermostLoops, + mix, canOverCompute, simdLoopStaticTripCount, simdOnly); + if (totVL > 1) { + if (collapsedInnermostLoops == outputRank) + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "variadic fully flattened"); else - onnxToKrnlSimdReport(op, /*successful*/ true, uVL, - estimatedSimdLoopTripCount, "variadic with manageable broadcast"); + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "variadic with manageable broadcast"); return getPartiallyFlattenedSimdCode(rewriter, create, &shapeHelper, op, outputMemRefType, operands, alignment, - uVL, collapsedInnermostLoops, hasNoBroadcast, - /*unary*/ false, enableParallel); + totVL, collapsedInnermostLoops, hasNoBroadcast, + /*unary*/ false, simdOnly, enableParallel); } - onnxToKrnlSimdReport(op, /*successful*/ false, 0, - estimatedSimdLoopTripCount, + onnxToKrnlSimdReport(op, /*successful*/ false, 0, simdLoopStaticTripCount, "no simd in variadic because no beneficial VL"); } else { onnxToKrnlSimdReport(op, /*successful*/ false, 0, 0, @@ -2474,7 +2339,7 @@ struct ONNXElementwiseVariadicOpLowering if (enableParallel) { int64_t parId; if (findSuitableParallelDimension( - lbs, ubs, 0, std::min((uint64_t)2, outputRank), parId)) { + lbs, ubs, 0, std::min((int64_t)2, outputRank), parId)) { create.krnl.parallel(loopDef[parId]); onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], "elementwise variadic not simdized"); @@ -2503,8 +2368,8 @@ struct ONNXElementwiseVariadicOpLowering // Obtain the next operand. SmallVector oprdAccessExprs; LogicalResult res = shapeHelper.getAccessExprs(operands[i], i, - outputAccessExprs, oprdAccessExprs, /*flattened dims*/ false, - hasNoBroadcast); + outputAccessExprs, oprdAccessExprs, + /*flattened dims*/ false, hasNoBroadcast); assert(succeeded(res) && "Could not compute access indices"); Value next = createKrnl.loadIE(operands[i], oprdAccessExprs); // Fold. @@ -2574,7 +2439,7 @@ struct ONNXWhereOpLowering : public ConversionPattern { assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); MemRefType outputMemRefType = mlir::cast(convertedType); - uint64_t outputRank = outputMemRefType.getRank(); + int64_t outputRank = outputMemRefType.getRank(); ONNXWhereOpAdaptor operandAdaptor(operands); // Shape helper. 
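The recurring `uint64_t` to `int64_t` switch for `outputRank` in these lowerings keeps the rank in the same signed type that `getRank()` and the parallel-dimension helpers use, so calls like `std::min((int64_t)2, outputRank)` need no per-call-site casts. The snippet below merely illustrates the mixed-signedness pitfall that a single consistent type avoids:

```cpp
// Why one consistent rank type helps: std::min deduces a single type for
// both arguments, so mixing signed and unsigned operands does not compile.
#include <algorithm>
#include <cstdint>

int main() {
  int64_t rank = 3;
  // std::min((int64_t)2, (uint64_t)rank);    // error: conflicting types
  int64_t capped = std::min<int64_t>(2, rank); // fine with one signed type
  return static_cast<int>(capped);
}
```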
@@ -2597,7 +2462,7 @@ struct ONNXWhereOpLowering : public ConversionPattern { if (enableParallel) { int64_t parId; if (findSuitableParallelDimension( - lbs, ubs, 0, std::min((uint64_t)2, outputRank), parId)) { + lbs, ubs, 0, std::min((int64_t)2, outputRank), parId)) { create.krnl.parallel(loopDef[parId]); onnxToKrnlParallelReport( op, true, parId, lbs[parId], ubs[parId], "where op not simdized"); diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index 87af6088ad..574638d510 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -222,7 +222,7 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { }); } - void computeTileSizeForMatVectProduct(Operation *op, int64_t mVL, + void computeTileSizeForMatVectProduct(Operation *op, int64_t VL, DimIndexExpr dimI, DimIndexExpr dimJ, DimIndexExpr dimK, int64_t &iRegTile, int64_t &jRegTile, int64_t &kRegTile, bool &simdize) const { @@ -232,21 +232,21 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { "no simd because disabled for mat * vec"); // Default values. - // Right can only tile i and k by (possibly distinct) multiple of mVL. - iRegTile = 2 * mVL; // SIMD dim during multi-reduction. + // Right can only tile i and k by (possibly distinct) multiple of VL. + iRegTile = 2 * VL; // SIMD dim during multi-reduction. jRegTile = 1; - kRegTile = 16 * mVL; // SIMD dim during multiplication. + kRegTile = 16 * VL; // SIMD dim during multiplication. if (dimK.isLiteral()) { int64_t constK = dimK.getLiteral(); // Register tile in the I Dim is really for the reduction. The - // computations will be further tiled to a multiple of mVL inside + // computations will be further tiled to a multiple of VL inside // krnl.matmul. - kRegTile = (constK / mVL) * mVL; // largest multiple - if (kRegTile > 64 * mVL) { - kRegTile = 64 * mVL; + kRegTile = (constK / VL) * VL; // largest multiple + if (kRegTile > 64 * VL) { + kRegTile = 64 * VL; LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: cap tiling k\n"; }); - } else if (kRegTile < mVL) { + } else if (kRegTile < VL) { // Not enough data, can only support i/k reg tile of 4. LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: disable k\n"; }); simdize = false; @@ -258,8 +258,8 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { if (dimI.isLiteral()) { int64_t constI = dimI.getLiteral(); if (constI < iRegTile) { - iRegTile = (constI / mVL) * mVL; // largest multiple - if (iRegTile < mVL) { + iRegTile = (constI / VL) * VL; // largest multiple + if (iRegTile < VL) { // Not enough data, can only support i/k reg tile of 4. 
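The mat-vec tiling above sizes the k register tile as the largest multiple of the vector length that fits the constant dimension, caps it to limit register pressure, and falls back to scalar code when even one vector does not fit. A standalone sketch of that heuristic (the 64x cap mirrors the constant used above):

```cpp
// Standalone sketch of the register-tile heuristic: largest multiple of VL
// that fits constK, capped, with SIMD disabled when one vector does not fit.
#include <cstdint>
#include <iostream>

int64_t pickKRegTile(int64_t constK, int64_t VL, bool &simdize) {
  int64_t tile = (constK / VL) * VL; // largest multiple of VL <= constK
  if (tile > 64 * VL)
    tile = 64 * VL;                  // cap to keep register pressure sane
  else if (tile < VL)
    simdize = false;                 // not enough data for a single vector
  return tile;
}

int main() {
  bool simdize = true;
  // (1000/8)*8 = 1000 exceeds the cap, so the tile becomes 64*8 = 512.
  std::cout << pickKRegTile(1000, 8, simdize) << "\n"; // 512
  // Only 5 elements: tile 0, SIMD disabled.
  std::cout << pickKRegTile(5, 8, simdize) << " " << simdize << "\n"; // 0 0
}
```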
LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: disable i\n"; }); simdize = false; @@ -307,9 +307,9 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { bool isMatVectorProduct = !DISABLE_MAT_VEC_PRODUCT && dimJ.isLiteral() && dimJ.getLiteral() == 1; if (isMatVectorProduct) { - int64_t mVL = create.vec.getMachineVectorLength(elementType); + int64_t archVL = create.vec.getArchVectorLength(elementType); computeTileSizeForMatVectProduct( - op, mVL, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); + op, archVL, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); } else { computeTileSizeForMatMatProduct( op, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); @@ -391,9 +391,9 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { bool isMatVectorProduct = !DISABLE_MAT_VEC_PRODUCT && dimJ.isLiteral() && dimJ.getLiteral() == 1; if (isMatVectorProduct) { - int64_t mVL = create.vec.getMachineVectorLength(elementType); + int64_t archVL = create.vec.getArchVectorLength(elementType); computeTileSizeForMatVectProduct( - op, mVL, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); + op, archVL, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); } else { computeTileSizeForMatMatProduct( op, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize); diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index 6c1824fe3a..8702377291 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -16,6 +16,7 @@ #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/SmallVectorHelper.hpp" #define DEBUG_TYPE "lowering-to-krnl" #define DEBUG_FORCE_SHUFFLE_REDUCTION 0 @@ -24,193 +25,113 @@ using namespace mlir; namespace onnx_mlir { -// support - -// Until num, inclusive. Negative numbers count from the back of the vector. -template -SmallVector firstFew(ValueRange vec, int64_t untilNum) { - SmallVector res; - int64_t size = vec.size(); - if (untilNum < 0) - untilNum += size; - // If untilNum<0... we get an empty vector, that is ok. - assert(untilNum < size && "out of bound"); - for (int64_t i = 0; i <= untilNum; ++i) - res.emplace_back(vec[i]); - return res; -} - -template -SmallVector firstFew(ArrayRef vec, int64_t untilNum) { - SmallVector res; - int64_t size = vec.size(); - if (untilNum < 0) - untilNum += size; - // If untilNum<0... we get an empty vector, that is ok. - assert(untilNum < size && "out of bound"); - for (int64_t i = 0; i <= untilNum; ++i) - res.emplace_back(vec[i]); - return res; -} - -template -SmallVector firstFew(SmallVectorImpl &vec, int64_t untilNum) { - SmallVector res; - int64_t size = vec.size(); - if (untilNum < 0) - untilNum += size; - // If untilNum<0... we get an empty vector, that is ok. - assert(untilNum < size && "out of bound"); - for (int64_t i = 0; i <= untilNum; ++i) - res.emplace_back(vec[i]); - return res; -} - -// From num, inclusive. Negative numbers count from the back of the vector. -template -SmallVector lastFew(ValueRange vec, int64_t fromNum) { - SmallVector res; - int64_t size = vec.size(); - if (fromNum < 0) - fromNum += size; - // If fromNum>= size... we get an empty vector, that is ok. 
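The `firstFew`/`lastFew` helpers deleted here now come from `src/Support/SmallVectorHelper.hpp`, which this patch includes at the top of Reduction.cpp. Their contract is easy to misread: bounds are inclusive and negative indices count from the back, as the standalone sketch below shows:

```cpp
// Standalone sketch of the firstFew/lastFew contract: inclusive bounds, with
// negative indices counting from the back of the container.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int> firstFew(const std::vector<int> &vec, int64_t untilNum) {
  int64_t size = static_cast<int64_t>(vec.size());
  if (untilNum < 0)
    untilNum += size; // e.g. -1 means "up to and including the last element"
  assert(untilNum < size && "out of bound");
  return std::vector<int>(vec.begin(), vec.begin() + (untilNum + 1));
}

std::vector<int> lastFew(const std::vector<int> &vec, int64_t fromNum) {
  int64_t size = static_cast<int64_t>(vec.size());
  if (fromNum < 0)
    fromNum += size;  // e.g. -2 means "the last two elements"
  assert(fromNum >= 0 && "out of bound");
  return std::vector<int>(vec.begin() + fromNum, vec.end());
}

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4};
  std::cout << firstFew(v, -2).size() << " "  // elements 0..3 -> 4
            << lastFew(v, -2).size() << "\n"; // elements 3..4 -> 2
}
```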
- assert(fromNum >= 0 && "out of bound"); - for (int64_t i = fromNum; i < size; ++i) - res.emplace_back(vec[i]); - return res; -} - -template -SmallVector lastFew(ArrayRef vec, int64_t fromNum) { - SmallVector res; - int64_t size = vec.size(); - if (fromNum < 0) - fromNum += size; - // If fromNum>= size... we get an empty vector, that is ok. - assert(fromNum >= 0 && "out of bound"); - for (int64_t i = fromNum; i < size; ++i) - res.emplace_back(vec[i]); - return res; -} - -template -SmallVector lastFew(SmallVectorImpl &vec, int64_t fromNum) { - SmallVector res; - int64_t size = vec.size(); - if (fromNum < 0) - fromNum += size; - // If fromNum>= size... we get an empty vector, that is ok. - assert(fromNum >= 0 && "out of bound"); - for (int64_t i = fromNum; i < size; ++i) - res.emplace_back(vec[i]); - return res; -} -// end support - enum RLegacy { Latest, UpTo13 }; -// +//===----------------------------------------------------------------------===// +// Defaults + +// Defines the VectorBuilder's CombiningKind associated with a given Op. template VectorBuilder::CombiningKind getCombiningKind() { llvm_unreachable("illegal combination kind"); } -// Identity values -template <> -Value getIdentityValue( - ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.negativeInf(type); +// Defines if the OP requires a divide by mean; false by default. +template +bool divideByMean() { + return false; } -template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MAX; -} +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceProdOp +//===----------------------------------------------------------------------===// template <> -Value getIdentityValue( +Value getIdentityValue( ConversionPatternRewriter &rewriter, Location loc, Type type) { MathBuilder createMath(rewriter, loc); - return createMath.negativeInf(type); + return createMath.constant(type, 1); } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MAX; +VectorBuilder::CombiningKind getCombiningKind() { + return VectorBuilder::CombiningKind::MUL; } - template <> -Value getIdentityValue( - ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.positiveInf(type); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::MulGop, 1}}; } template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MIN; -} +struct ScalarOp { + using FOp = arith::MulFOp; + using IOp = arith::MulIOp; +}; template <> -Value getIdentityValue( +Value getIdentityValue( ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.positiveInf(type); + return getIdentityValue(rewriter, loc, type); } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MIN; +VectorBuilder::CombiningKind getCombiningKind() { + return getCombiningKind(); } - template <> -Value getIdentityValue( - ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.constant(type, 1); +GenOpMix getGenOpMix(Type t, Operation *op) { + return getGenOpMix(t, op); } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MUL; -} +struct 
ScalarOp { + using FOp = arith::MulFOp; + using IOp = arith::MulIOp; +}; + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceSumOp +//===----------------------------------------------------------------------===// template <> -Value getIdentityValue( +Value getIdentityValue( ConversionPatternRewriter &rewriter, Location loc, Type type) { MathBuilder createMath(rewriter, loc); - return createMath.constant(type, 1); + return createMath.constant(type, 0); } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::MUL; +VectorBuilder::CombiningKind getCombiningKind() { + return VectorBuilder::CombiningKind::ADD; } +template <> +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; +} +template <> +struct ScalarOp { + using FOp = arith::AddFOp; + using IOp = arith::AddIOp; +}; template <> Value getIdentityValue( ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.constant(type, 0); + return getIdentityValue(rewriter, loc, type); } - template <> VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::ADD; + return getCombiningKind(); } - template <> -Value getIdentityValue( - ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.constant(type, 0); +GenOpMix getGenOpMix(Type t, Operation *op) { + return getGenOpMix(t, op); } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::ADD; -} +struct ScalarOp { + using FOp = arith::AddFOp; + using IOp = arith::AddIOp; +}; + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMeanOp +//===----------------------------------------------------------------------===// template <> Value getIdentityValue( @@ -218,57 +139,43 @@ Value getIdentityValue( MathBuilder createMath(rewriter, loc); return createMath.constant(type, 0); } - template <> VectorBuilder::CombiningKind getCombiningKind() { return VectorBuilder::CombiningKind::ADD; } - template <> -Value getIdentityValue( - ConversionPatternRewriter &rewriter, Location loc, Type type) { - MathBuilder createMath(rewriter, loc); - return createMath.constant(type, 0); +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::ArithmeticGop, 1}}; } - template <> -VectorBuilder::CombiningKind getCombiningKind() { - return VectorBuilder::CombiningKind::ADD; +bool divideByMean() { + return true; } - -// Scalar ops template <> -struct ScalarOp { - using FOp = arith::MulFOp; - using IOp = arith::MulIOp; -}; - -template <> -struct ScalarOp { - using FOp = arith::MulFOp; - using IOp = arith::MulIOp; -}; - -template <> -struct ScalarOp { +struct ScalarOp { using FOp = arith::AddFOp; using IOp = arith::AddIOp; }; template <> -struct ScalarOp { - using FOp = arith::AddFOp; - using IOp = arith::AddIOp; -}; - +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return getIdentityValue(rewriter, loc, type); +} template <> -struct ScalarOp { - using FOp = arith::AddFOp; - using IOp = arith::AddIOp; -}; - +VectorBuilder::CombiningKind getCombiningKind() { + return getCombiningKind(); +} template <> -struct ScalarOp { +GenOpMix getGenOpMix(Type t, Operation *op) { + return getGenOpMix(t, op); +} +template <> +bool divideByMean() { + 
return divideByMean(); +} +template <> +struct ScalarOp { using FOp = arith::AddFOp; using IOp = arith::AddIOp; }; @@ -276,16 +183,21 @@ struct ScalarOp { //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReduceMaxOp //===----------------------------------------------------------------------===// + template <> -Value emitScalarOpFor(ConversionPatternRewriter &rewriter, - Location loc, Operation *op, Type elementType, - ArrayRef scalarOperands) { +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { MathBuilder createMath(rewriter, loc); - Value lhs = scalarOperands[0]; - Value rhs = scalarOperands[1]; - return createMath.max(lhs, rhs); + return createMath.negativeInf(type); +} +template <> +VectorBuilder::CombiningKind getCombiningKind() { + return VectorBuilder::CombiningKind::MAX; +} +template <> +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::MinMaxGop, 1}}; } - template <> Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Type elementType, @@ -296,11 +208,46 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, return createMath.max(lhs, rhs); } +template <> +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return getIdentityValue(rewriter, loc, type); +} +template <> +VectorBuilder::CombiningKind getCombiningKind() { + return getCombiningKind(); +} +template <> +GenOpMix getGenOpMix(Type t, Operation *op) { + return getGenOpMix(t, op); +} +template <> +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + return emitScalarOpFor( + rewriter, loc, op, elementType, scalarOperands); +} //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReduceMinOp //===----------------------------------------------------------------------===// + template <> -Value emitScalarOpFor(ConversionPatternRewriter &rewriter, +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + MathBuilder createMath(rewriter, loc); + return createMath.positiveInf(type); +} +template <> +VectorBuilder::CombiningKind getCombiningKind() { + return VectorBuilder::CombiningKind::MIN; +} +template <> +GenOpMix getGenOpMix(Type t, Operation *op) { + return {{GenericOps::MinMaxGop, 1}}; +} +template <> +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Type elementType, ArrayRef scalarOperands) { MathBuilder createMath(rewriter, loc); @@ -310,32 +257,206 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, } template <> -Value emitScalarOpFor(ConversionPatternRewriter &rewriter, +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return getIdentityValue(rewriter, loc, type); +} +template <> +VectorBuilder::CombiningKind getCombiningKind() { + return getCombiningKind(); +} +template <> +GenOpMix getGenOpMix(Type t, Operation *op) { + return getGenOpMix(t, op); +} +template <> +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Type elementType, ArrayRef scalarOperands) { - MathBuilder createMath(rewriter, loc); - Value lhs = scalarOperands[0]; - Value rhs = scalarOperands[1]; - return createMath.min(lhs, rhs); + return emitScalarOpFor( + rewriter, loc, op, elementType, scalarOperands); +} + 
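The reduction lowering is reorganized around per-op trait specializations: each reduce op defines its identity value, vector combining kind, generic-op mix, scalar op, and whether a final divide-by-mean is needed, and the legacy (opset 13 and earlier) variants simply forward to the latest definitions. A compact model of that pattern in plain C++, with made-up op tags:

```cpp
// Standalone model of the per-op trait pattern: specialize per reduction op,
// give divide-by-mean a false default, and let legacy ops forward.
#include <iostream>
#include <limits>

struct ReduceSum {};
struct ReduceSumV13 {}; // stands in for the "UpTo13" legacy variant
struct ReduceMax {};

template <typename Op> double identity();                    // must specialize
template <typename Op> bool divideByMean() { return false; } // default

template <> double identity<ReduceSum>() { return 0.0; }
template <> double identity<ReduceMax>() {
  return -std::numeric_limits<double>::infinity();
}
// Legacy op forwards to the latest definition, as in the patch.
template <> double identity<ReduceSumV13>() { return identity<ReduceSum>(); }

int main() {
  std::cout << identity<ReduceSumV13>() << " " << identity<ReduceMax>() << " "
            << divideByMean<ReduceSum>() << "\n"; // 0 -inf 0
}
```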
+//===----------------------------------------------------------------------===// + +using MDBuilder = + MultiDialectBuilder; + +//===----------------------------------------------------------------------===// +// Helper function to perform reduction when an entire tensor is reduced to a +// single value. Support the reduction for up to 2 operations at once. If only +// one is needed, then pass ONNXNoneOp in the second slot. +// Return true if we can optimize the reduction, false otherwise. + +// TODO: alexe add support for parallel +// TODO: alexe see if the new simd infrastructure can be used. +template +bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc, + Operation *op, Value input, Value &alloc1, Value &alloc2) { + // Create scope. + IndexExprScope scope(&rewriter, loc); + MDBuilder create(rewriter, loc); + // Get info. + MemRefType inputType = mlir::cast(input.getType()); + Type elementType = inputType.getElementType(); + int64_t inputRank = inputType.getRank(); + DimsExpr inputDims, flatInputDims; + create.krnlIE.getShapeAsSymbols(input, inputDims); + // Flatten entirely the input memref. + Value flatInput = create.mem.reshapeToFlatInnermost( + input, inputDims, flatInputDims, inputRank); + + // Has one or 2 reductions? + bool hasTwoRed = true; + if constexpr (std::is_same::value) + hasTwoRed = false; + + // Study SIMD. Assume here that since SIMD is determined by the input type + // (which is expected to be the same as the output scalar value), both + // reduction will have the same archVL. + GenOpMix mix = getGenOpMix(elementType, op); + if (hasTwoRed) { + GenOpMix mix2 = getGenOpMix(elementType, op); + mix = computeGenOpMixUnion(mix, mix2); + } + int64_t collapsedInnermostLoops = inputRank; + int64_t simdLoopStaticTripCount; + bool simdOnly, canOverCompute = false; + int64_t totVL = + computeSuitableUnrollFactor(inputType, collapsedInnermostLoops, mix, + canOverCompute, simdLoopStaticTripCount, simdOnly); + // Current simdized loop only support SIMD only scheme. + if (!simdOnly) { + totVL = capVLForSimdOnly(inputType, totVL, simdLoopStaticTripCount); + } + if (totVL <= 1) + return false; // TODO alexe: consider staying here with VL=1 + IndexExpr VLIndexExpr = LitIE(totVL); + + // Compute type of small temporary reduction vector. + MemRefType outputType = MemRefType::get({}, elementType); + MemRefType redType = MemRefType::get({totVL}, elementType); + VectorType vecType = VectorType::get({totVL}, elementType); + + // Initialize first reduction. + Value zero = create.math.constantIndex(0); + /*output*/ alloc1 = create.mem.alloc(outputType); + Value redAlloc1 = create.mem.alignedAlloc(redType); + Value identity1 = getIdentityValue( + rewriter, create.getLoc(), elementType); + Value initVec1 = create.vec.splat(vecType, identity1); + create.vec.store(initVec1, redAlloc1, {zero}); + // Init second reduction. + alloc2 = nullptr; + Value redAlloc2 = nullptr; + if (hasTwoRed) { + /*output*/ alloc2 = create.mem.alloc(outputType); + redAlloc2 = create.mem.alignedAlloc(redType); + Value identity2 = getIdentityValue( + rewriter, create.getLoc(), elementType); + Value initVec2 = create.vec.splat(vecType, identity2); + create.vec.store(initVec2, redAlloc2, {zero}); + } + + // Loop over SIMD values. 
+ ValueRange loopDef = create.krnl.defineLoops(1); + ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL); + create.krnl.iterate(loopDef, {blockedLoopDef[0]}, {zero}, + {flatInputDims[0].getValue()}, [&](KrnlBuilder &ck, ValueRange loopInd) { + MDBuilder create(ck); + // Input values, loaded as a vector. + SmallVector inAccessVals; + inAccessVals.emplace_back(loopInd[0]); + Value inputVec = create.vec.load(vecType, flatInput, inAccessVals); + // Process first reduction. + Value redVec1 = create.vec.load(vecType, redAlloc1, {zero}); + Value accumulatedVec1 = emitScalarOpFor( + rewriter, create.getLoc(), op, vecType, {redVec1, inputVec}); + create.vec.store(accumulatedVec1, redAlloc1, {zero}); + // Process second reduction. + if (hasTwoRed) { + Value redVec2 = create.vec.load(vecType, redAlloc2, {zero}); + Value accumulatedVec2 = emitScalarOpFor( + rewriter, create.getLoc(), op, vecType, {redVec2, inputVec}); + create.vec.store(accumulatedVec2, redAlloc2, {zero}); + } + }); + + // First reduction horizontal sum. + Value reductionVec1 = create.vec.load(vecType, redAlloc1, {zero}); + Value res1 = + create.vec.reduction(getCombiningKind(), reductionVec1); + // Second reduction horizontal sum. + Value res2 = nullptr; + if (hasTwoRed) { + Value reductionVec2 = create.vec.load(vecType, redAlloc2, {zero}); + res2 = create.vec.reduction( + getCombiningKind(), reductionVec2); + } + + // Handle mean if any. + Value divisorForMean = nullptr; + if (divideByMean() || divideByMean()) { + // Compute the divisor that is the number of elements participated in + // reduction, i.e., 'divisor = size of input / size of output, where output + // size == 1'. + divisorForMean = create.math.cast(elementType, flatInputDims[0].getValue()); + } + if (divideByMean()) + res1 = create.math.div(res1, divisorForMean); + if (hasTwoRed && divideByMean()) + res2 = create.math.div(res2, divisorForMean); + + // Save result. + create.affineKMem.store(res1, alloc1, {}); + if (hasTwoRed) + create.affineKMem.store(res2, alloc2, {}); + + if (hasTwoRed) + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "fused reduction to a scalar"); + else + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "reduction to a scalar"); + + return true; +} + +void emitMinMaxReductionToScalar(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Value input, Value &minAlloc, Value &maxAlloc, + bool enableSIMD, bool enableParallel) { + // Try optimized path first. + if (enableSIMD && emitFullSIMDReductionFor( + rewriter, loc, op, input, minAlloc, maxAlloc)) + return; + // Could not optimize the pattern, generate default path. + MultiDialectBuilder create(rewriter, loc); + Type elementType = mlir::cast(input.getType()).getElementType(); + MemRefType outputType = MemRefType::get({}, elementType); + Value none = create.onnx.none(); + // Generate reductions. + minAlloc = create.onnx.toMemref( + create.onnx.reduceMin(outputType, input, none, false)); + maxAlloc = create.onnx.toMemref( + create.onnx.reduceMax(outputType, input, none, false)); } -// This duplicated code can be eliminated with if constexpr in c++ 17 -// Or onnx uses input for axes for all ops +//===----------------------------------------------------------------------===// +// Generic reduction code (for current and legacy using "if constexpr". +// Function use SIMD if all reductions occur consecutively in the innermost +// loops. 
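`emitFullSIMDReductionFor` keeps one VL-wide accumulator per reduction, folds VL input elements into it on every iteration, finishes with a single horizontal reduction, and divides once when the op is a mean; the simd-only requirement means the flattened trip count must be a multiple of VL. A scalar C++ model of that shape, addition only:

```cpp
// Scalar model of the tensor-to-scalar SIMD reduction: a VL-lane accumulator
// is folded lane-wise over the flattened input, then the lanes are combined
// once at the end. Assumes n is a multiple of VL (the simd-only scheme).
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

double reduceAddToScalar(const std::vector<double> &flatInput, int64_t VL,
                         bool isMean) {
  int64_t n = static_cast<int64_t>(flatInput.size());
  assert(n % VL == 0 && "simd-only scheme: no leftover iterations");
  std::vector<double> lanes(VL, 0.0);  // splat of the identity value
  for (int64_t i = 0; i < n; i += VL)  // blocked "SIMD" loop
    for (int64_t l = 0; l < VL; ++l)
      lanes[l] += flatInput[i + l];    // lane-wise accumulate
  double res = 0.0;                    // horizontal reduction at the end
  for (double v : lanes)
    res += v;
  return isMean ? res / static_cast<double>(n) : res;
}

int main() {
  std::vector<double> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::cout << reduceAddToScalar(x, 4, /*isMean=*/false) << " "
            << reduceAddToScalar(x, 4, /*isMean=*/true) << "\n"; // 36 4.5
}
```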
+ template struct ONNXReductionOpLowering : public OpConversionPattern { using OpAdaptor = typename ONNXReductionOp::Adaptor; bool enableSIMD = false; - bool computeMean = false; bool enableParallel = false; - using MDBuilder = - MultiDialectBuilder; - ONNXReductionOpLowering(TypeConverter &typeConverter, MLIRContext *ctx, - bool enableSIMD, bool enableParallel, bool computeMean = false) + bool enableSIMD, bool enableParallel) : OpConversionPattern(typeConverter, ctx), - enableSIMD(enableSIMD), computeMean(computeMean) { + enableSIMD(enableSIMD) { this->enableParallel = enableParallel && OnnxToKrnlLoweringConfiguration::enableSpecificParallelOps.isEnabled( @@ -482,8 +603,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern { bool hasHorizontalSimdSupport = false; bool parallelSimd = false; int64_t innermostLoopCollapse = 0; - int64_t VL = 0; - int64_t estimatedSimdLoopTripCount = 0; + int64_t totVL = 1; + int64_t simdLoopStaticTripCount = 0; // With dynamic axes, use this Value maskVal = nullptr; @@ -522,36 +643,40 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (horizontalSimd || parallelSimd) { assert(!(horizontalSimd && parallelSimd) && "expected at most horizontal or parallel SIMD"); - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); DimsExpr inputDims; create.krnlIE.getShapeAsSymbols(input, inputDims); - int64_t unroll = 4; if (horizontalSimd) { #if !DEBUG_FORCE_SHUFFLE_REDUCTION VectorBuilder::CombiningKind kind = getCombiningKind(); hasHorizontalSimdSupport = - supportedHorizontalSIMDOp(vms, kind, elementOutType); + supportedHorizontalSIMDOp(kind, elementOutType); #endif - if (!hasHorizontalSimdSupport) { - // Does not have SIMD horizontal support, so use a scheme that - // unroll the innermost non-simd loop by VL. Because trip counts - // of such loops could be small (e.g. GPT2 = 8), we don't want a - // large VL here. - unroll = 1; - } } - LLVM_DEBUG(llvm::dbgs() - << " SIMD: study with init unroll " << unroll << "\n"); - VL = create.vec.computeSuitableUnrollFactor(vms, memRefInType, - inputDims, innermostLoopCollapse, unroll, /*canPad*/ false, - estimatedSimdLoopTripCount); + // Currently only vectorize loops whose SIMD dimension is a multiple + // of the natural SIMD width. Aka, we don't deal with SIMD of partial + // vectors. + GenOpMix mix = getGenOpMix(elementOutType, op); + bool simdOnly, canOverCompute = false; + totVL = + computeSuitableUnrollFactor(memRefInType, innermostLoopCollapse, + mix, canOverCompute, simdLoopStaticTripCount, simdOnly); + if (!hasHorizontalSimdSupport) { + // When we don't have horizontal SIMD support, we use a code gen + // scheme that relies on unrolling. So we don't want any unrollVL + // here. Some benchmarks have small trip counts (e.g. GPT2: 8). + totVL = capVLForMaxUnroll(memRefInType, totVL, 1); + } + // Current code gen scheme only support SIMD only scheme. + if (!simdOnly) { + totVL = + capVLForSimdOnly(memRefInType, totVL, simdLoopStaticTripCount); + } LLVM_DEBUG(llvm::dbgs() << " SIMD: " << innermostLoopCollapse - << " loops, VL " << VL << "\n"); - if (!VL) { + << " loops, totVL " << totVL << "\n"); + if (totVL <= 1) { horizontalSimd = parallelSimd = false; - LLVM_DEBUG(llvm::dbgs() << " SIMD: no good VL\n"); + LLVM_DEBUG(llvm::dbgs() << " SIMD: no good totVL\n"); } } } @@ -630,8 +755,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } } } - LLVM_DEBUG(llvm::dbgs() << " SIMD " << (VL ? 
"" : "im") - << "possible with vector length " << VL << "\n"); + LLVM_DEBUG(llvm::dbgs() << " SIMD " << (totVL > 1 ? "" : "im") + << "possible with totVL " << totVL << "\n"); ////////////////////////////////////////////////////////////////////// // Insert an allocation and deallocation for the result of this operation. @@ -662,7 +787,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // Used if compute mean Value divisorForMean = nullptr; - if (computeMean) { + if (divideByMean()) { // Compute the divisor that is the number of elements participated in // reduction, i.e., 'divisor = size of input / size of output'. IndexExprScope scope(create.krnl); @@ -683,16 +808,16 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (horizontalSimd) { if (hasHorizontalSimdSupport) { genHorizontalSimdReduction(rewriter, create, op, elementOutType, input, - alloc, inRank, outRank, VL, innermostLoopCollapse, isKeepdims, + alloc, inRank, outRank, totVL, innermostLoopCollapse, isKeepdims, divisorForMean, enableParallel); - onnxToKrnlSimdReport(op, /*successful*/ true, VL, - estimatedSimdLoopTripCount, "horizontal"); + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "horizontal"); } else { genShuffleHorizontalSimdReduction(rewriter, create, op, elementOutType, - input, alloc, inRank, outRank, VL, innermostLoopCollapse, + input, alloc, inRank, outRank, totVL, innermostLoopCollapse, isKeepdims, divisorForMean, enableParallel); - onnxToKrnlSimdReport(op, /*successful*/ true, VL, - estimatedSimdLoopTripCount, "shuffle-horizontal"); + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "shuffle-horizontal"); } } else { genScalarReduction(rewriter, create, op, elementOutType, input, alloc, @@ -704,7 +829,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { else msg = "unsupported"; onnxToKrnlSimdReport( - op, /*successful*/ false, /*vl*/ 0, estimatedSimdLoopTripCount, msg); + op, /*successful*/ false, /*vl*/ 0, simdLoopStaticTripCount, msg); } rewriter.replaceOp(op, alloc); return success(); @@ -715,6 +840,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { Value alloc, int64_t inRank, int64_t outRank, bool dynamicAxes, Value maskVal, std::map &outInDimMap, Value divisorForMean, bool enableParallel) const { + LLVM_DEBUG(llvm::dbgs() << "gen scalar reduction\n"); ////////////////////////////////////////////////////////////////////// // There are two required and one optional Krnl loops: // - One to initialize the result memref, @@ -765,7 +891,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { }); // 3. Define an Krnl loop to compute mean (optional). 
- if (computeMean) { + if (divideByMean()) { // Compute mean ValueRange loop3Def = create.krnl.defineLoops(outRank); SmallVector lbs3(outRank, LiteralIndexExpr(0)); @@ -793,23 +919,40 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } } - bool supportedHorizontalSIMDOp(VectorMachineSupport *vms, + bool supportedHorizontalSIMDOp( VectorBuilder::CombiningKind getCombiningKind, Type elementType) const { int64_t len; switch (getCombiningKind) { case VectorBuilder::CombiningKind::ADD: - len = vms->getVectorLength(GenericOps::SumAcrossGop, elementType); + len = VectorMachineSupport::getArchVectorLength( + GenericOps::SumAcrossGop, elementType); break; case VectorBuilder::CombiningKind::MIN: case VectorBuilder::CombiningKind::MAX: - len = vms->getVectorLength(GenericOps::SumAcrossGop, elementType); + len = VectorMachineSupport::getArchVectorLength( + GenericOps::SumAcrossGop, elementType); break; default: - len = 0; + len = 1; } - return len != 0; + return len != 1; } + // Generate a single reduction, eventually using a horizontal reduction + // (which, if the hardware supports it, will be one instruction; otherwise it + // will be simulated by several operations). + // + // flatInput has been flattened from [N][M][R1][R2] to [N][M][R1*R2], where + // the SIMD reduction is done along the last dim. By definition of what we + // support here, R1*R2 mod VL = 0, namely the reduction dimension is a + // multiple of VL (no partial SIMD). + // + // tmpAlloc has been flattened (if keepDim is true) to [N][M]. + // + // outLoopInd defines which [n][m] is to be used to load the inputs to be + // reduced (flatInput[n][m][*]) and where the reduction is to be saved + // (flatAlloc[n][m]). + void genOneHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, VectorType vecType, Value tmpAlloca, Value flatInput, Value flatAlloc, Value initVec, @@ -839,19 +982,22 @@ struct ONNXReductionOpLowering : public OpConversionPattern { Value accumulatedVal = create.vec.reduction(getCombiningKind(), reductionVec); // other operation... - if (computeMean) { + if (divideByMean()) { accumulatedVal = create.math.div(accumulatedVal, divisorForMean); } // Store tmp into result. create.krnl.store(accumulatedVal, flatAlloc, outLoopInd); } + // We assume here that the hardware has an efficient SIMD horizontal + // operation, so we simply generate one horizontal SIMD reduction for each + // reductions that needs to be performed. void genHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, Value input, Value alloc, int64_t inRank, int64_t outRank, int64_t VL, int64_t collapsedInnermostLoops, bool isKeepDims, Value divisorForMean, bool enableParallel) const { - + LLVM_DEBUG(llvm::dbgs() << "gen horizontal simd reduction\n"); assert(VL > 1 && "expected simd here"); VectorType vecType = VectorType::get({VL}, elementType); // Flatten the input: in[N][M][Red1][Red2] -> in[N][M][Red1*Red2] @@ -908,6 +1054,22 @@ struct ONNXReductionOpLowering : public OpConversionPattern { }); } + // We perform here VL Simd Reductions at once. We are guaranteed that there + // are VL reductions to be performed. The algorithm works in 2 steps. + // + // In the first step, we perform the SIMD reductions of VL distinct reductions + // using the "emitScalarOp" associated with that operation. 
At the end of this + // step, we have VL distinct partial reductions, where each of the VL vector + // register have a partial reduction in each of their own VL SIMD slots. + // + // In the second step, we reduce each VL vectors of VL partial values into one + // vector of VL fully-reduced values. We use shuffle patterns to generate + // efficient code where each of the temporary vectors always contain VL + // values. This is implemented by the create.vec.multiReduction operation. + // + // Finally, the VL full reductions are stored as a vector operation in the + // flatAlloc[m][n+0...+VL-1] output. + void genVlHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, VectorType vecType, Value tmpBlockedAlloca, Value flatInput, Value flatAlloc, Value initVec, @@ -919,7 +1081,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { create.vec.store( initVec, tmpBlockedAlloca, {create.math.constantIndex(i), zero}); } - // Blocked Simd loop. + // First step: blocked simd loop. ValueRange simdLoopDef = create.krnl.defineLoops(1); ValueRange blockedSimdLoopDef = create.krnl.block(simdLoopDef[0], VL); create.krnl.iterate(simdLoopDef, {blockedSimdLoopDef[0]}, {zero}, {simdUB}, @@ -948,6 +1110,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { create.vec.store(accumulatedVec, tmpBlockedAlloca, {tmpInd, zero}); } /* intra block output loop */ }); /* blocked simd loop */ + // Step 2 // Load all temp vectors. SmallVector redIn, redOut; for (int64_t i = 0; i < VL; ++i) { @@ -965,7 +1128,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { assert(redOut.size() == 1 && "expected only one val"); Value accumulatedVal = redOut[0]; // Perform the mean computation if required. - if (computeMean) { + if (divideByMean()) { Value divisorForMeanVec = create.vec.splat(vecType, divisorForMean); accumulatedVal = create.math.div(accumulatedVal, divisorForMeanVec); } @@ -974,13 +1137,22 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } // Solution when there is no horizontal SIMD op support and that shuffle ops - // are needed. + // are needed. Assuming a (flattened) output reduction tensor of [N][M], this + // algorithm will block the inter dimension of the output tensor by VL. For + // each block of VL values to be reduced, we use the efficient functions that + // computes them using shuffles (genVlHorizontalSimdReduction). 
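The comment continues below with the handling of the leftover block; to make the two-step scheme concrete, here is a scalar model of step 2, where each of the VL vectors of VL partial values collapses into one lane of the final result vector (the real code keeps every temporary VL-wide by using shuffles, via `create.vec.multiReduction`):

```cpp
// Scalar model of step 2 of the shuffle-based reduction: VL vectors, each
// holding VL partial values of one reduction, collapse into a single vector
// with one fully reduced value per lane.
#include <iostream>
#include <vector>

std::vector<double> multiReduceAdd(
    const std::vector<std::vector<double>> &partials) {
  std::vector<double> result; // one lane per reduction
  for (const auto &lanes : partials) {
    double sum = 0.0;
    for (double v : lanes) // done with shuffles in the real code
      sum += v;
    result.push_back(sum);
  }
  return result;
}

int main() {
  // "VL" = 2: two reductions, each with two partial sums.
  auto res = multiReduceAdd({{1, 2}, {10, 20}});
  std::cout << res[0] << " " << res[1] << "\n"; // 3 30
}
```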
For the last + // block (if any) that has fewer than VL remaining reductions to be performed, + // we simply perform r 1 && "expected simd here"); IndexExpr VLIndexExpr = LiteralIndexExpr(VL); VectorType vecType = VectorType::get({VL}, elementType); @@ -1091,17 +1263,15 @@ void populateLoweringONNXReductionOpPattern(RewritePatternSet &patterns, bool enableParallel) { patterns.insert< ONNXReductionOpLowering, + ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering, + ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering>( typeConverter, ctx, enableSIMD, enableParallel); - patterns.insert< - ONNXReductionOpLowering, - ONNXReductionOpLowering>( - typeConverter, ctx, enableSIMD, enableParallel, /*computeMean=*/true); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp index 1af693b95a..7ff02d5849 100644 --- a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp @@ -4,7 +4,7 @@ //===----------- Normalization.cpp - Lowering Normalization Ops -----------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -581,12 +581,12 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { llvm_unreachable("unexpected case"); } - bool isSimdizable(MDBuilder &create, OP_TYPE lnOp, ADAPTOR_TYPE adaptor, - SHAPE_HELPER_TYPE &shapeHelper, int64_t &VL, + bool isSimdizable(OP_TYPE lnOp, ADAPTOR_TYPE adaptor, + SHAPE_HELPER_TYPE &shapeHelper, int64_t &totVL, BroadcastKind &scaleBroadcastKind, BroadcastKind &biasBroadcastKind, IndexExpr &scaleModFactor, IndexExpr &biasModFactor) const { - VL = 0; + totVL = 1; Operation *op = lnOp.getOperation(); if (!enableSIMD) { onnxToKrnlSimdReport( @@ -602,8 +602,6 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { int64_t axis = getAxisInRange(lnOp.getAxis(), XRank); int64_t lowRank = XRank - axis; // Detect if we can use SIMD based on inout/X output/Y shape. - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); // Implementation relies into splitting the input X into a 2D vector, with // outer dim is batches, and inner dims is where the mean/stddev is @@ -617,21 +615,14 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { return false; } - // Do not want to disable SIMD for lack of sum across support at this - // stage. Type elementType = XMemRefType.getElementType(); - // - // if (vms->getVectorLength(GenericOps::SumAcrossGop, elementType) <= 0) { - // LLVM_DEBUG(llvm::dbgs() << " SIMD: unsupported sum across, fail\n"); - // return false; - // } - + // TODO: Use old scheme here, maybe update to new scheme. 
int64_t simdLoopStaticTripCount; - VL = create.vec.computeSuitableUnrollFactor(vms, XMemRefType, XDims, - lowRank, 4, /*canPad*/ false, simdLoopStaticTripCount); + totVL = computeSuitableUnrollFactor(XMemRefType, lowRank, 4, + /*canOverCompute*/ false, simdLoopStaticTripCount); LLVM_DEBUG(llvm::dbgs() << " SIMD: LayerNormalization " << simdLoopStaticTripCount - << " loops, VL " << VL << "\n";); - if (VL == 0) { + << " loops, totVL " << totVL << "\n";); + if (totVL <= 1) { onnxToKrnlSimdReport(op, /*successful*/ false, 0, simdLoopStaticTripCount, "no simd because could not find beneficial VL"); return false; @@ -658,7 +649,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { } } onnxToKrnlSimdReport( - op, /*successful*/ true, VL, simdLoopStaticTripCount, "successful"); + op, /*successful*/ true, totVL, simdLoopStaticTripCount, "successful"); return true; } @@ -673,15 +664,16 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { SHAPE_HELPER_TYPE shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); - int64_t VL; + int64_t totVL; BroadcastKind scaleBroadcastKind, biasBroadcastKind; IndexExpr scaleModFactor, biasModFactor; - bool isSIMD = isSimdizable(create, lnOp, adaptor, shapeHelper, VL, + bool isSIMD = isSimdizable(lnOp, adaptor, shapeHelper, totVL, scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor); if (isSIMD) { - return generateSIMDCode(rewriter, loc, lnOp, adaptor, shapeHelper, 4, VL, - scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor); + return generateSIMDCode(rewriter, loc, lnOp, adaptor, shapeHelper, 4, + totVL, scaleBroadcastKind, biasBroadcastKind, scaleModFactor, + biasModFactor); } return generateGenericLayerNormOpONNXCode( rewriter, loc, lnOp, this->typeConverter); @@ -759,9 +751,9 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { /* temps [B][vec] */ Value redMemRef, Value redMemRef2, /* index expr param */ IndexExpr redDim, /* value params */ Value i, Value epsilon, - /* int params */ int64_t B, int64_t VL, BroadcastKind scaleBroadcastKind, - BroadcastKind biasBroadcastKind, IndexExpr scaleModFactor, - IndexExpr biasModFactor) const { + /* int params */ int64_t B, int64_t totVL, + BroadcastKind scaleBroadcastKind, BroadcastKind biasBroadcastKind, + IndexExpr scaleModFactor, IndexExpr biasModFactor) const { // Bool isTraditionalLayerNorm is true when computing traditional layer // norm, not the faster RMS version. bool isTraditionalLayerNorm = false; @@ -770,7 +762,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { // Vector type. Type elementType = mlir::cast(YMemRef.getType()).getElementType(); - VectorType vecType = VectorType::get({VL}, elementType); + VectorType vecType = VectorType::get({totVL}, elementType); // Init the two reductions. Value init = create.math.constant(elementType, 0.0); Value initVec = create.vec.splat(vecType, init); @@ -782,7 +774,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { }); // Perform reduction of entire vectors. IndexExpr izero = LiteralIndexExpr(0); - create.affineKMem.forIE(izero, redDim, VL, + create.affineKMem.forIE(izero, redDim, totVL, [&](onnx_mlir::AffineBuilderKrnlMem &ck, mlir::Value j) { MDBuilder create(ck); // load X, compute X**2, sum into reductions. @@ -836,7 +828,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { invStdDev[d] = create.math.div(oneFloat, stdDev); }); // Normalize of entire vectors. 
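The layer-norm kernel accumulates two reductions per row, sum(x) and sum(x*x), then derives the mean and inverse standard deviation before the normalization pass that follows. A scalar sketch of that math, assuming the traditional (non-RMS) variant and omitting scale and bias:

```cpp
// Standalone sketch of the two-reduction layer-norm math: one pass builds
// sum(x) and sum(x*x), a second pass normalizes with the derived mean and
// inverse standard deviation.
#include <cmath>
#include <iostream>
#include <vector>

void layerNorm(std::vector<double> &x, double epsilon) {
  double sum = 0.0, sumSq = 0.0;
  for (double v : x) { // first pass: the two reductions
    sum += v;
    sumSq += v * v;
  }
  double n = static_cast<double>(x.size());
  double mean = sum / n;
  double var = sumSq / n - mean * mean; // E[x^2] - E[x]^2
  double invStdDev = 1.0 / std::sqrt(var + epsilon);
  for (double &v : x) // second pass: normalize
    v = (v - mean) * invStdDev;
}

int main() {
  std::vector<double> x = {1, 2, 3, 4};
  layerNorm(x, 1e-5);
  std::cout << x[0] << " " << x[3] << "\n"; // roughly -1.34 and 1.34
}
```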
- create.affineKMem.forIE(izero, redDim, VL, + create.affineKMem.forIE(izero, redDim, totVL, [&](onnx_mlir::AffineBuilderKrnlMem &ck, mlir::Value j) { MDBuilder create(ck); // load X, compute X**2, sum into reductions. @@ -883,7 +875,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { LogicalResult generateSIMDCode(ConversionPatternRewriter &rewriter, Location loc, OP_TYPE lnOp, ADAPTOR_TYPE &adaptor, - SHAPE_HELPER_TYPE &shapeHelper, int64_t B, int64_t VL, + SHAPE_HELPER_TYPE &shapeHelper, int64_t B, int64_t totVL, BroadcastKind scaleBroadcastKind, BroadcastKind biasBroadcastKind, IndexExpr scaleModFactor, IndexExpr biasModFactor) const { @@ -946,7 +938,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { shapeHelper.getOutputDims(2), axis, invStdDevMemRef, invStdDevFlatMemRef); // Alloc mem for reductions (should be private if parallel) - MemRefType tmpRedType = MemRefType::get({B, VL}, elementType); + MemRefType tmpRedType = MemRefType::get({B, totVL}, elementType); // Iterate over 1st dim by block ValueRange loopDefs = create.krnl.defineLoops(1); IndexExpr zero = LiteralIndexExpr(0); @@ -992,7 +984,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { scaleFlatMemRef, biasFlatMemRef, YFlatMemRef, meanFlatMemRef, invStdDevFlatMemRef, tmpRedMemRef, tmpRedMemRef2, XFlatDims[1], blockLocalInd, epsilon, - 1, VL, scaleBroadcastKind, biasBroadcastKind, + 1, totVL, scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor); }); /* for inside blocked loop */ }, @@ -1003,7 +995,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { scaleFlatMemRef, biasFlatMemRef, YFlatMemRef, meanFlatMemRef, invStdDevFlatMemRef, tmpRedMemRef, tmpRedMemRef2, XFlatDims[1], blockedLoopIndices[0], epsilon, - B, VL, scaleBroadcastKind, biasBroadcastKind, + B, totVL, scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor); }); }); /* blocked loop */ diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index a8d4da59da..861f6bb4dc 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -4,7 +4,7 @@ //====----- ONNXToKrnlCommon.cpp - ONNX dialects to Krnl lowering ---------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -112,7 +112,8 @@ Value OnnxToKrnlBuilder::transpose(const Value input, bool isScalarValue(Value value) { ShapedType stype = mlir::dyn_cast(value.getType()); assert(stype && "expected shaped type"); - return stype.getRank() == 0; + return (stype.getRank() == 0) || + (stype.getRank() == 1 && stype.getShape()[0] == 1); } /// Check if all operands are scalar values at compile time. @@ -139,20 +140,6 @@ bool hasOneElement(Value value) { return true; } -// Same as above, but from the innermost dimensions up to innerDim. -bool hasOneElementInInnermostDims(Value value, int64_t innerDim) { - if (isScalarValue(value)) - return true; - ShapedType type = mlir::dyn_cast(value.getType()); - assert(type && "expected shaped type"); - mlir::ArrayRef shape = type.getShape(); - int64_t rank = type.getRank(); - for (int64_t i = rank - innerDim; i < rank; ++i) - if (shape[i] != 1) - return false; - return true; -} - /// Check if the value is a KrnlGlobalOp with a dense attribute of non-negative /// integer constants. 
bool indicesAreNonNegativeConstants(Value indices) { @@ -643,6 +630,200 @@ bool findSuitableParallelDimension(llvm::SmallVectorImpl &lb, return false; } +//===----------------------------------------------------------------------===// +// Support functions for simd. +//===----------------------------------------------------------------------===// + +// New style. +int64_t computeSuitableUnrollFactor(MemRefType memRefType, + int64_t collapsedInnermostLoops, GenOpMix &genOps, bool canOverCompute, + int64_t &simdLoopStaticTripCount, bool &simdOnly) { + // Default return values for no simd. + simdLoopStaticTripCount = 0; + simdOnly = false; + + // Analyze size of SIMD iterations. + int64_t staticSimdSize; + bool isStatic = MemRefBuilder::getStaticMemSize( + memRefType, staticSimdSize, -collapsedInnermostLoops); + + Type elementType = memRefType.getElementType(); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); + LLVM_DEBUG(llvm::dbgs() << " simd archVL is " << archVL << "\n"); + + // Element type does nt support SIMD. + if (archVL <= 1) { + LLVM_DEBUG(llvm::dbgs() << " simd disabled: no simd for this type\n"); + return 1; + } + if (isStatic && staticSimdSize < archVL) { + LLVM_DEBUG(llvm::dbgs() << " simd disabled: static trip count " + << staticSimdSize << " too short for a VL\n"); + return 1; + } + // Gather operation statics + int64_t vectorizedOpNum, scalarOpNum; + double avgVL = VectorMachineSupport::getAvgArchVectorLength( + genOps, elementType, vectorizedOpNum, scalarOpNum); + if (avgVL < 1.5) { + LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with " + << avgVL << " avg VL\n"); + return 1; + } + LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n"); + + // Define a target max unroll as a function of register pressure. + int64_t unrollVL; + int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum(); + if (vectorizedOpNum >= vrNum / 2) + unrollVL = 2; + else if (vectorizedOpNum >= vrNum / 4) + unrollVL = 4; + else + unrollVL = 8; + int64_t totVL = archVL * unrollVL; + // Refine unrolling factor so that it is suitable for short loops. + if (isStatic && (staticSimdSize < unrollVL * archVL)) { + int64_t newUnroll = floor((1.0 * staticSimdSize) / (1.0 * archVL)); + LLVM_DEBUG(llvm::dbgs() << " simd enable: size " << staticSimdSize + << " , archVL " << archVL << ", unroll " << unrollVL + << ", reduced to " << newUnroll << "\n"); + unrollVL = newUnroll; + totVL = archVL * unrollVL; + if (canOverCompute && staticSimdSize % totVL != 0) { + // Does not divide; since we can over compute, increase unrollVL by 1. + LLVM_DEBUG( + llvm::dbgs() << " simd enable: can over compute, boost unrollVL\n"); + ++unrollVL; + totVL = archVL * unrollVL; + } + // Size control: if no ILP (unrollVL==1) or little ILP (unrollVL==2) with a + // leftover scalar loop, don't bother. + if (unrollVL == 1) { + LLVM_DEBUG(llvm::dbgs() << " simd disable: too small unrollVL (1)\n"); + return 1; + } + if (!canOverCompute && unrollVL == 2 && staticSimdSize % totVL != 0) { + LLVM_DEBUG(llvm::dbgs() + << " simd disable: small unrollVL (2) with leftovers\n"); + return 1; + } + } + LLVM_DEBUG(llvm::dbgs() << " simd enable: unrollVL " << unrollVL << "\n"); + // Fill in the output values. Now that we have SIMD, simdLoopStaticTripCount + // is either the static simd size if the trip is not runtime, or -1 if its + // runtime. + simdLoopStaticTripCount = isStatic ? 
staticSimdSize : -1; + // Now that we have SIMD, we have SIMD only if the static component of the + // SIMD loop is positive and a multiple of VL. + simdOnly = (staticSimdSize > 1) && (staticSimdSize % totVL == 0); + LLVM_DEBUG(llvm::dbgs() << " simd enable: totVL " << totVL << ", simd-only " + << simdOnly << "\n"); + if (canOverCompute && !simdOnly) { + LLVM_DEBUG( + llvm::dbgs() << " simd enable: can over compute, force simdOnly\n"); + simdOnly = true; + } + return archVL * unrollVL; +} + +int64_t capVLForMaxUnroll( + MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL) { + if (totVL == 1) + return 1; // Simd already disabled, nothing to cap. + Type elementType = memRefType.getElementType(); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); + int64_t unrollVL = totVL / archVL; + assert(archVL * unrollVL == totVL && "expected archVL to divide totVL"); + if (unrollVL > maxUnrollVL) { + LLVM_DEBUG(llvm::dbgs() << " simd enable: unrollVL " << unrollVL + << " capped at " << maxUnrollVL << "\n"); + unrollVL = maxUnrollVL; + } + return archVL * unrollVL; +} + +int64_t capVLForSimdOnly( + MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) { + if (totVL == 1) + return 1; // Simd already disabled, nothing to cap. + if (simdLoopStaticTripCount <= 1) { + // There is no static component to simd loop trip count. + LLVM_DEBUG(llvm::dbgs() << " simd disable: dyn trip count, no simdOnly\n"); + return 1; + } + int64_t archVL = + VectorMachineSupport::getArchVectorLength(memRefType.getElementType()); + int64_t unrollVL = totVL / archVL; + assert(archVL * unrollVL == totVL && "expected archVL to divide totVL"); + for (int64_t u = unrollVL; u > 0; --u) { + totVL = u * archVL; + if (simdLoopStaticTripCount % totVL == 0) { + // Success. + LLVM_DEBUG(llvm::dbgs() + << " simd enable: simd only with totVL " << totVL << "\n"); + return totVL; + } + } + // Did not find any unroll factor for which totVL divides static trip count. + LLVM_DEBUG(llvm::dbgs() << " simd disable: no simdONLY for trip count\n"); + return 1; +} + +// Old style. +int64_t computeSuitableUnrollFactor(MemRefType memRefType, + int64_t collapsedInnermostLoops, int64_t maxUnrollVL, bool canOverCompute, + int64_t &simdLoopStaticTripCount) { + assert(collapsedInnermostLoops > 0 && "expected at least one collapsed loop"); + assert(maxUnrollVL > 0 && "expected positive max simd unroll"); + simdLoopStaticTripCount = 0; // Initially assume no SIMD. + Type elementType = memRefType.getElementType(); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); + LLVM_DEBUG(llvm::dbgs() << " simd archVL is " << archVL << "\n"); + if (archVL <= 1) { + LLVM_DEBUG(llvm::dbgs() << " simd disabled: no simd\n"); + return 1; + } + int64_t staticSize; + bool isStaticSize = MemRefBuilder::getStaticMemSize( + memRefType, staticSize, -collapsedInnermostLoops); + if (isStaticSize && staticSize < archVL) { + LLVM_DEBUG(llvm::dbgs() << " simd disabled: trip count " << staticSize + << " too short for a archVL of " << archVL << "\n"); + return 1; + } + // Unless otherwise disabled, here is the estimated trip count. + if (canOverCompute && + collapsedInnermostLoops == (int64_t)memRefType.getRank()) { + // Fully collapsed and can add padding to be fine + simdLoopStaticTripCount = isStaticSize ? staticSize : -1; + return maxUnrollVL * archVL; + } + // We have a partially flattened operator. Since we do only simdize entire + // loops (i.e. 
we don't support scalar epilogues at this time), make sure + // the static size is a multiple of the VL. Get the VL of the store + // (output's element type). + if (staticSize % archVL != 0) { + LLVM_DEBUG(llvm::dbgs() + << " simd disabled: partial flattened dims " + << collapsedInnermostLoops << " with size " << staticSize + << " is not 0 mod archVL " << archVL << "\n"); + return 1; + } + // See if we can get a unroll factor. + for (int64_t u = maxUnrollVL; u > 0; --u) { + if (staticSize % (u * archVL) == 0) { + LLVM_DEBUG(llvm::dbgs() + << " partial flattened dims " << collapsedInnermostLoops + << " with size " << staticSize << " works with VL " << archVL + << " and unroll " << u << "\n"); + simdLoopStaticTripCount = isStaticSize ? staticSize : -1; + return u * archVL; + } + } + llvm_unreachable("should always find u==1 feasible"); +} + //===----------------------------------------------------------------------===// // Support functions for reporting. //===----------------------------------------------------------------------===// diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 8bfb84475d..485ece2370 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_TO_KRNL_H +#define ONNX_MLIR_ONNX_TO_KRNL_H #include @@ -92,9 +93,6 @@ bool hasAllScalarValues(mlir::ValueRange values); // HasOneElement returns true for scalars as well as tensors that contain only // one elements, such as 1xf32 or 1x1x1xf32. bool hasOneElement(mlir::Value value); -// Same as hasOneElement, but check only from the innerDims innermost -// dimensions. -bool hasOneElementInInnermostDims(mlir::Value value, int64_t innerDims); /// Check if the value is a KrnlGlobalOp with a dense attribute of non-negative /// integer constants. @@ -227,29 +225,50 @@ mlir::Value emitScalarOpFor(mlir::ConversionPatternRewriter &rewriter, // int. Thus we look at the type the first input argument, and not the output // elementType. mlir::Type actualElementType = - MathBuilder::elementTypeWithVector(scalarOperands[0].getType()); + MathBuilder::elementTypeOfScalarOrVector(scalarOperands[0]); // Perform int or float operation depending on the actual elementary type. if (mlir::isa(actualElementType)) { // Generate the integer code only if the scalar integer op is non-void // (unsupported) and non-int (supported by custom sequence of ops). if constexpr (!(std::is_same, NotSuportedScalarOp>::value) && - !(std::is_same, CustomScalarOp>::value)) + !(std::is_same, CustomScalarOp>::value)) { + llvm::SmallVector scalarsSplatted(scalarOperands); + MultiDialectBuilder create(rewriter, loc); + create.math.splatToMatch(scalarsSplatted); return rewriter.create>( - loc, elementType, scalarOperands, std::nullopt); + loc, elementType, scalarsSplatted, std::nullopt); + } llvm_unreachable("unsupported integer operation"); } else if (mlir::isa(actualElementType)) { // Generate the floating point code only if the scalar integer op is - // non-void (unsupported) and non-int (supported by custom sequence of ops). + // non-void (unsupported) and non-int (supported by custom sequence of + // ops). 
if constexpr (!(std::is_same, NotSuportedScalarOp>::value) && - !(std::is_same, CustomScalarOp>::value)) + !(std::is_same, CustomScalarOp>::value)) { + llvm::SmallVector scalarsSplatted(scalarOperands); + MultiDialectBuilder create(rewriter, loc); + create.math.splatToMatch(scalarsSplatted); return rewriter.create>( - loc, elementType, scalarOperands, std::nullopt); + loc, elementType, scalarsSplatted, std::nullopt); + } llvm_unreachable("unsupported float operation"); } else { llvm_unreachable("unsupported element type"); } } +// ============================================================================= +// Template for SIMD analysis + +// Default template for ops that do not support SIMD. For the ones that support +// SIMD, we must create an `getGenOpsMix` template that returns their +// corresponding mix of generic operations. + +template +GenOpMix getGenOpMix(mlir::Type elementType, mlir::Operation *op) { + return {{GenericOps::ScalarOnlyGop, 1}}; +} + //===----------------------------------------------------------------------===// // Type conversion from Onnx types to Krnl types: // - from Tensor type to the Standard dialect MemRef type @@ -277,8 +296,8 @@ class KrnlTypeConverter : public mlir::TypeConverter { llvm::all_of(call.getResultTypes(), f); } - // Return the default alignment value used when allocating a MemRef buffer for - // the given type. E.g. some special types for accelerators requires + // Return the default alignment value used when allocating a MemRef buffer + // for the given type. E.g. some special types for accelerators requires // 4K-aligned buffers. static int64_t getDefaultAllocAlignment(mlir::Type type); }; @@ -359,9 +378,11 @@ void populateLoweringONNXNonMaxSuppressionOpPattern( // `Quantization` directory methods: void populateLoweringONNXDynamicQuantizeLinearOpPattern( - mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); -void populateLoweringONNXQuantizeLinearOpPattern( - mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *, + bool enableSIMD, bool enableParallel); +void populateLoweringONNXQuantizeLinearOpPattern(mlir::RewritePatternSet &, + mlir::TypeConverter &, mlir::MLIRContext *, bool enableSIMD, + bool enableParallel); // `RNN` directory methods: void populateLoweringONNXGRUOpPattern( @@ -595,6 +616,59 @@ bool findSuitableParallelDimension(llvm::SmallVectorImpl &lb, llvm::SmallVectorImpl &ub, int64_t firstInclusiveDim, int64_t lastExclusiveDim, int64_t &parDim, int64_t minSize = 4); +//===----------------------------------------------------------------------===// +// Support functions for determining simd unrolling. +//===----------------------------------------------------------------------===// + +// Compute a suitable SIMD Vector length (which may be a multiple of the +// hardware vector length, up to maxUnrollVL times). If the dims are too +// small, return 1 (no suitable simd). The collapsedInnermostLoops parameter +// indicates how many inner dimensions of the memref are considered for +// vectorization. If all of them are considered and padding is possible (aka +// canOverCompute==true), then we can always generate SIMD code with the +// maxSIMD unroll factor. Otherwise, we must ensure that the cumulative static +// size (dynamic sizes are ignored here ) of the array is a multiple of the +// Vector Length associated with this type. If it is not, then no SIMD code +// gen is possible (return 1). 
If it is possible, return the largest SIMD +// unroll factor (starting at maxUnrollVL) that divide the cumulative static +// size of the memref being collapsed for SIMD. simdLoopStaticTripCount: +// provide an estimation of the SIMD loop trip count. If runtime, return -1; +// if cannot simdize, return 0; otherwise, return that literal. +int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType, + int64_t collapsedInnermostLoops, int64_t maxUnrollVL, bool canOverCompute, + int64_t &simdLoopStaticTripCount); + +// Compute a suitable SIMD Vector Length (totVL). If no SIMD is suitable, +// return totVL = 1. Type determine the archVL for the given memRefType. Then +// compute the average amount of SIMD operations given the mix of Generic +// Operations in that loop. If the element type does not support SIMD, or +// there are too few SIMD operations, or the innermost loop has too few +// (static) loop iterations, SIMD will be disabled (return totVL=1). +// Otherwise, the register pressure is then taken into account to determine a +// suitable additional unrolling (by multiple of VL) so as to suitably exploit +// the available SIMD hardware. +// +// In this call, we assume that code gen can handle SIMD loops with trip count +// that are not known to be a multiple of VL. The simdOnly boolean flag will +// be set to true if all loop iterations can be handled using SIMD code with +// totVL. In other words, simdOnly is set to true if we can guarantee that +// there is no scalar loop for the leftovers not handled by the simd loop. +// +// Now some SIMD scheme may allow to write past the last original loop +// iterations; in this case we may ignore the simdOnly flag . +int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType, + int64_t collapsedInnermostLoops, GenOpMix &GenOps, bool canOverCompute, + int64_t &simdLoopStaticTripCount, bool &simdOnly); +// Cap totVL so that it is at most maxUnrollVL * archVL. +int64_t capVLForMaxUnroll( + mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL); +// Enabling a simdOnly code generation scheme by capping totVL so that it +// divides simdLoopStaticTripCount. When not possible (either because +// there is no totVL that divides simdLoopStaticTripCount or trip count is +// runtime only), then disable SIMD by returning totVL = 1. +int64_t capVLForSimdOnly(mlir::MemRefType memRefType, int64_t totVL, + int64_t simdLoopStaticTripCount); + //===----------------------------------------------------------------------===// // Support functions for reporting. //===----------------------------------------------------------------------===// @@ -622,10 +696,9 @@ void onnxToKrnlParallelReport(mlir::Operation *op, bool successful, // the ONNX operation parallelized. // // Loop level: -1: none; 0: outermost; 1: next to outermost... -// Parallel loop trip count; 0: none; -1: runtime only; >0: min number known at -// compile time. -// Comment: explanation of how parallelism was achieved / or failed. Comments -// cannot have ',' in them. +// Parallel loop trip count; 0: none; -1: runtime only; >0: min number known +// at compile time. Comment: explanation of how parallelism was achieved / or +// failed. Comments cannot have ',' in them. inline void onnxToKrnlParallelReport(mlir::Operation *op, bool successful = false, int64_t loopLevel = -1, int64_t parallelLoopTripCount = 0, const std::string &comment = "") { @@ -656,9 +729,9 @@ inline void onnxToKrnlParallelReport(mlir::Operation *op, bool successful, // compile time. 
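As a hedged sketch (not part of the patch), the helpers declared above are meant to be combined roughly as follows. The op mix, memref, maxUnrollVL value, and the pickTotVL wrapper are hypothetical; the computeSuitableUnrollFactor, capVLForMaxUnroll, capVLForSimdOnly, and GenericOps names are the ones introduced here.

```cpp
// Sketch: pick a total vector length for an elementwise lowering whose scalar
// body is roughly two adds, one multiply, and one min/max, simdizing only the
// innermost dimension of memRefType.
static int64_t pickTotVL(mlir::MemRefType memRefType, bool &simdOnly,
    int64_t &simdLoopStaticTripCount) {
  GenOpMix mix = {{GenericOps::ArithmeticGop, 2}, {GenericOps::MulGop, 1},
      {GenericOps::MinMaxGop, 1}};
  int64_t totVL = computeSuitableUnrollFactor(memRefType,
      /*collapsedInnermostLoops=*/1, mix, /*canOverCompute=*/false,
      simdLoopStaticTripCount, simdOnly);
  // totVL is archVL * unrollVL, or 1 when SIMD is not considered worthwhile.
  totVL = capVLForMaxUnroll(memRefType, totVL, /*maxUnrollVL=*/4);
  // A scheme that cannot emit a scalar remainder loop would additionally call:
  //   totVL = capVLForSimdOnly(memRefType, totVL, simdLoopStaticTripCount);
  return totVL;
}
```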
// Comment: explanation of how SIMD was achieved / or failed. Comments cannot // have ',' in them. Use the following comment templates. If SIMD is not -// supported, comments should be "unsupported". If SIMD is supported but fails, -// comment should be "no simd [in ] because ." When simd -// succeeds, comment indicates what type of pattern is used. +// supported, comments should be "unsupported". If SIMD is supported but +// fails, comment should be "no simd [in ] because ." +// When simd succeeds, comment indicates what type of pattern is used. inline void onnxToKrnlSimdReport(mlir::Operation *op, bool successful = false, int64_t vectorLength = 0, int64_t simdLoopTripCount = 0, const std::string &comment = "") { @@ -667,4 +740,12 @@ inline void onnxToKrnlSimdReport(mlir::Operation *op, bool successful = false, op, successful, vectorLength, simdLoopTripCount, comment); } +// Compute the min and max of input, allocate and save the results into +// minAlloc and maxAlloc. +void emitMinMaxReductionToScalar(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, mlir::Operation *op, mlir::Value input, + mlir::Value &minAlloc, mlir::Value &maxAlloc, bool enableSIMD, + bool enableParallel); + } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToKrnl/PerfectHash.hpp b/src/Conversion/ONNXToKrnl/PerfectHash.hpp index c2f42ff746..b66339d333 100644 --- a/src/Conversion/ONNXToKrnl/PerfectHash.hpp +++ b/src/Conversion/ONNXToKrnl/PerfectHash.hpp @@ -4,7 +4,7 @@ //====--------------- PerfectHash.hpp - Perfect Hash Table ----------------===// // -// Copyright 2021 The IBM Research Authors. +// Copyright 2021-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PERFECT_HASH_H +#define ONNX_MLIR_PERFECT_HASH_H #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -57,3 +58,4 @@ class PerfectHash { }; } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp index 9f485b984d..fdebe15d86 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp @@ -4,7 +4,7 @@ //===--- DynamicQuantizeLinear.cpp - Lowering DynamicQuantizeLinear Op ----===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,26 +13,79 @@ //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/SmallVectorHelper.hpp" using namespace mlir; namespace onnx_mlir { -// TODO may consider SIMD and parallel. +// Implementation of quantize helper function. +// TODO: add parallel. 
+void emitDynamicQuantizationLinearScalarParameters( + ConversionPatternRewriter &rewriter, Location loc, Operation *op, + MemRefType inputType, MemRefType quantizedType, Value input, Value qMin, + Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint, + bool enableSIMD, bool enableParallel) { + MultiDialectBuilder create(rewriter, loc); + + // Types + Type elementType = inputType.getElementType(); + Type quantizedElementType = quantizedType.getElementType(); + + // Equations: + // y_scale = (max(x) - min(x))/(qMax - qMin) + // intermediate_zero_point = qMin - min(x)/y_scale + // y_zero_point = cast(round(saturate(intermediate_zero_point))) + // y = saturate (round (x / y_scale) + y_zero_point) + // + // where, saturate is to clip to [0, 255] for ui8. + + Value inputMinAlloc, inputMaxAlloc; + emitMinMaxReductionToScalar(rewriter, loc, op, input, inputMinAlloc, + inputMaxAlloc, enableSIMD, enableParallel); + Value xMin = create.krnl.load(inputMinAlloc); + Value xMax = create.krnl.load(inputMaxAlloc); + + // Include 0 to max(x) and min(x). + // x_min = min(min(x), 0) + // x_max = max(max(x), 0) + Value zero = create.math.constant(elementType, 0.0); + xMax = create.math.max(xMax, zero); + xMin = create.math.min(xMin, zero); + // Compute y_scale. + Value xDiff = create.math.sub(xMax, xMin); + Value boundDiff = create.math.sub(qMax, qMin); + scale = create.math.div(xDiff, boundDiff); + + // Compute y_zero_point. + Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); + // Saturate zero point. + Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); + // Round zero point. + zeroPoint = create.math.round(saturateZeroPoint); + quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint); +} + struct ONNXDynamicQuantizeLinearOpLowering : public OpConversionPattern { - ONNXDynamicQuantizeLinearOpLowering( - TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} + ONNXDynamicQuantizeLinearOpLowering(TypeConverter &typeConverter, + MLIRContext *ctx, bool enableSIMD, bool enableParallel) + : OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD), + enableParallel(enableParallel) {} + + bool enableSIMD = false; + bool enableParallel = false; + + using LocalDialectBuilder = MultiDialectBuilder; LogicalResult matchAndRewrite(ONNXDynamicQuantizeLinearOp dqlOp, ONNXDynamicQuantizeLinearOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - using LocalDialectBuilder = MultiDialectBuilder; Operation *op = dqlOp.getOperation(); Location loc = ONNXLoc(op); LocalDialectBuilder create(rewriter, loc); @@ -51,8 +104,6 @@ struct ONNXDynamicQuantizeLinearOpLowering // Types Type elementType = xMemRefType.getElementType(); - Type quantizedElementType = yMemRefType.getElementType(); - int64_t rank = xMemRefType.getRank(); // Get shape. ONNXDynamicQuantizeLinearOpShapeHelper shapeHelper( @@ -67,76 +118,19 @@ struct ONNXDynamicQuantizeLinearOpLowering Value YZeroPoint = create.mem.alignedAlloc( yZeroPointMemRefType, shapeHelper.getOutputDims(2)); - // TODO: consider SIMD version of this. - - // Equations: - // y_scale = (max(x) - min(x))/(qmax - qmin) - // intermediate_zero_point = qmin - min(x)/y_scale - // y_zero_point = cast(round(saturate(itermediate_zero_point))) - // y = saturate (round (x / y_scale) + y_zero_point) - // - // where, saturate is to clip to [0, 255] for ui8. - - // QMax, QMin. 
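A concrete numeric check of the equations above may help: standalone C++ with hypothetical statistics min(x) = -1.5 and max(x) = 6.0 and the ui8 range 0..255. It mirrors only the arithmetic, not the generated code.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  double qMin = 0.0, qMax = 255.0;
  double xMin = std::min(-1.5, 0.0); // include 0 in the observed range
  double xMax = std::max(6.0, 0.0);
  double scale = (xMax - xMin) / (qMax - qMin);                 // 7.5 / 255
  double zeroPoint =
      std::round(std::clamp(qMin - xMin / scale, qMin, qMax));  // 51
  // Quantize one element, x = 2.3: round(x / scale) + zero_point, saturated.
  double y = std::clamp(std::round(2.3 / scale) + zeroPoint, qMin, qMax); // 129
  std::printf("scale=%g zero_point=%g y=%g\n", scale, zeroPoint, y);
  return 0;
}
```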
Value qMax = create.math.constant(elementType, 255.0); Value qMin = create.math.constant(elementType, 0.0); - Value QMax = create.mem.alignedAlloc(yScaleMemRefType); - create.krnl.store(qMax, QMax); - Value QMin = create.mem.alignedAlloc(yScaleMemRefType); - create.krnl.store(qMin, QMin); - - // Compute max(x) and min (x). - Value none = create.onnx.none(); - Value XMax = create.onnx.toMemref( - create.onnx.reduceMax(yScaleMemRefType, X, none, false)); - Value XMin = create.onnx.toMemref( - create.onnx.reduceMin(yScaleMemRefType, X, none, false)); - Value xMax = create.krnl.load(XMax); - Value xMin = create.krnl.load(XMin); - // Include 0 to max(x) and min(x). - // x_min = min(min(x), 0) - // x_max = max(max(x), 0) - Value zero = create.math.constant(elementType, 0.0); - Value greaterThanZero = create.math.sgt(xMax, zero); - xMax = create.math.select(greaterThanZero, xMax, zero); - Value lessThanZero = create.math.slt(xMin, zero); - xMin = create.math.select(lessThanZero, xMin, zero); - - // Compute y_scale. - Value scale = create.math.div( - create.math.sub(xMax, xMin), create.math.sub(qMax, qMin)); - create.krnl.store(scale, YScale); + Value scale, zeroPoint, zeroPointInt; - // Compute y_zero_point. - Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); - // Saturate zero point. - Value saturateZeroPoint = - create.onnx.clip(interZeroPoint, qMin, qMax, /*scalarType=*/true); - // Round zero point. - Value zeroPoint = create.onnx.round(saturateZeroPoint, /*scalarType=*/true); - Value zeroPointInt = create.math.cast(quantizedElementType, zeroPoint); + emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op, + xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt, + enableSIMD, enableParallel); + create.krnl.store(scale, YScale); create.krnl.store(zeroPointInt, YZeroPoint); - // Compute y. 
- ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); - create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(0), - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { - MultiDialectBuilder create( - createKrnl); - Value x = create.krnl.load(X, loopInd); - // Scale - Value scaleX = create.math.div(x, scale); - // Round - Value roundX = create.onnx.round(scaleX, /*scalarType=*/true); - // Adjust - Value adjustX = create.math.add(roundX, zeroPoint); - // Saturate - Value saturateX = - create.onnx.clip(adjustX, qMin, qMax, /*scalarType=*/true); - Value res = create.math.cast(quantizedElementType, saturateX); - create.krnl.store(res, Y, loopInd); - }); + emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, + yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, + zeroPoint, enableSIMD, enableParallel); rewriter.replaceOp(op, {Y, YScale, YZeroPoint}); onnxToKrnlSimdReport(op); @@ -145,9 +139,10 @@ struct ONNXDynamicQuantizeLinearOpLowering }; void populateLoweringONNXDynamicQuantizeLinearOpPattern( - RewritePatternSet &patterns, TypeConverter &typeConverter, - MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, + bool enableSIMD, bool enableParallel) { + patterns.insert( + typeConverter, ctx, enableSIMD, enableParallel); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp new file mode 100644 index 0000000000..124b854bde --- /dev/null +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//==--- QuantizeHelper.hpp - Helper functions for Quantization Op lowering --=// +// +// Copyright 2023-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains definitions of helper functions for quantization lowering. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" + +namespace onnx_mlir { + +// Given an input, scale, zero point, qMin, and qMax, perform a linear +// quantization and store in alloc. +void emitQuantizationLinearScalarParameters( + mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Operation *op, mlir::MemRefType inputType, + mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims, + mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale, + mlir::Value zeroPoint, bool enableSIMD, bool enableParallel); + +// Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin +// and qMax. 
+void emitDynamicQuantizationLinearScalarParameters( + mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Operation *op, mlir::MemRefType inputType, + mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin, + mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint, + mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel); +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 07151fbad7..83b2094fc7 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -4,7 +4,7 @@ //===-------- QuantizeLinear.cpp - Lowering QuantizeLinear Op -------------===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -16,21 +16,97 @@ #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/SmallVectorHelper.hpp" using namespace mlir; namespace onnx_mlir { +// Helper function for quantization. +void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, + Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax, + Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) { + MultiDialectBuilder create( + rewriter, loc); + + // Types + Type quantizedElementType = quantizedType.getElementType(); + int64_t rank = inputType.getRank(); + + // Flatten the input data and outputs + DimsExpr inputDims, flatInputDims, flatAllocDims; + inputDims = allocDims; // Unput and output have the same shape. + Value flatInput = + create.mem.reshapeToFlatInnermost(input, inputDims, flatInputDims, rank); + Value flatAlloc = + create.mem.reshapeToFlatInnermost(alloc, allocDims, flatAllocDims, rank); + + // Determine a suitable SIMD vector length for this loop. + int64_t totVL = 1; + int64_t simdLoopStaticTripCount = 0; + bool simdOnly = false; + if (enableSIMD) { + int64_t innermostLoopCollapse = 1; // Only innermost is simdized. + bool canOverCompute = false; + GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5}, + {GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2}, + {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, + {GenericOps::FloorGop, 2}}; + totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, + innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, + simdOnly); + } + + IndexExpr zero = LitIE(0); + IndexExpr simdLb = zero; + IndexExpr simdUb = flatAllocDims[0]; + // Create access functions for input X and output Y. 
+ DimsExpr inputAF; + inputAF.emplace_back(zero); + DimsExpr outputAF; + outputAF.emplace_back(zero); + create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, + {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, + [&](KrnlBuilder &kb, ArrayRef inputVals, + SmallVectorImpl &resVals, int64_t VL) { + MultiDialectBuilder create(kb); + Value x = inputVals[0]; + // Scale + Value scaleX = create.math.div(x, scale); + // Round + Value roundX = create.math.round(scaleX); + // Adjust + Value adjustX = create.math.add(roundX, zeroPoint); + // Saturate + Value saturateX = create.math.clip(adjustX, qMin, qMax); + Value res = create.math.cast(quantizedElementType, saturateX); + resVals.emplace_back(res); + }); + if (totVL > 1) + onnxToKrnlSimdReport(op, /*successful*/ true, totVL, + simdLoopStaticTripCount, "quantizationLinear whole tensor"); + else + onnxToKrnlSimdReport(op, /*successful*/ false, 0, 0, + "no simd in quantizationLinear whole tensor"); +} + struct ONNXQuantizeLinearOpLowering : public OpConversionPattern { - ONNXQuantizeLinearOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} + ONNXQuantizeLinearOpLowering(TypeConverter &typeConverter, MLIRContext *ctx, + bool enableSIMD, bool enableParallel) + : OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD), + enableParallel(enableParallel) {} + + bool enableSIMD = false; + bool enableParallel = false; + + using LocalDialectBuilder = MultiDialectBuilder; LogicalResult matchAndRewrite(ONNXQuantizeLinearOp qlOp, ONNXQuantizeLinearOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - using LocalDialectBuilder = MultiDialectBuilder; Operation *op = qlOp.getOperation(); Location loc = ONNXLoc(op); LocalDialectBuilder create(rewriter, loc); @@ -49,7 +125,6 @@ struct ONNXQuantizeLinearOpLowering // Types Type elementType = xMemRefType.getElementType(); Type quantizedElementType = yMemRefType.getElementType(); - int64_t rank = xMemRefType.getRank(); // Does not support per-axis and i8. assert(yScaleMemRefType.getRank() == 0 && @@ -91,28 +166,9 @@ struct ONNXQuantizeLinearOpLowering } else zeroPoint = create.math.constant(elementType, 0.0); - // Compute y. 
- ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); - create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(0), - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { - MultiDialectBuilder create( - createKrnl); - Value x = create.krnl.load(X, loopInd); - // Scale - Value scaleX = create.math.div(x, scale); - // Round - Value roundX = create.onnx.round(scaleX, /*scalarType=*/true); - // Adjust - Value adjustX = create.math.add(roundX, zeroPoint); - // Saturate - Value lessThanMin = create.math.slt(adjustX, qMin); - Value saturateX = create.math.select(lessThanMin, qMin, adjustX); - Value lessThanMax = create.math.slt(saturateX, qMax); - saturateX = create.math.select(lessThanMax, saturateX, qMax); - Value res = create.math.cast(quantizedElementType, saturateX); - create.krnl.store(res, Y, loopInd); - }); + emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, + yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, + zeroPoint, enableSIMD, enableParallel); rewriter.replaceOp(op, {Y}); onnxToKrnlSimdReport(op); @@ -121,8 +177,10 @@ struct ONNXQuantizeLinearOpLowering }; void populateLoweringONNXQuantizeLinearOpPattern(RewritePatternSet &patterns, - TypeConverter &typeConverter, MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); + TypeConverter &typeConverter, MLIRContext *ctx, bool enableSIMD, + bool enableParallel) { + patterns.insert( + typeConverter, ctx, enableSIMD, enableParallel); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index 4c5f179dbd..c418c1d002 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -4,7 +4,7 @@ //===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RNN_BASE_KRNL_H +#define ONNX_MLIR_RNN_BASE_KRNL_H #include "mlir/IR/AffineExpr.h" @@ -228,3 +229,4 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { }; } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToStablehlo/CMakeLists.txt b/src/Conversion/ONNXToStablehlo/CMakeLists.txt index fe2c842b21..92d25fdd93 100644 --- a/src/Conversion/ONNXToStablehlo/CMakeLists.txt +++ b/src/Conversion/ONNXToStablehlo/CMakeLists.txt @@ -47,6 +47,7 @@ add_onnx_mlir_library(OMONNXToStablehlo Math/Gemm.cpp Math/MatMul.cpp Math/Reduction.cpp + Math/Softmax.cpp NN/Conv.cpp NN/ConvTranspose.cpp NN/Normalization.cpp diff --git a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp index 1550214d60..0bcc199506 100644 --- a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp +++ b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp @@ -29,6 +29,7 @@ void populateONNXToStablehloConversionPattern( populateLoweringONNXGemmOpToStablehloPattern(patterns, ctx); populateLoweringONNXMatMulOpToStablehloPattern(patterns, ctx); populateLoweringONNXReductionOpToStablehloPattern(patterns, ctx); + populateLoweringONNXSoftmaxOpToStablehloPattern(patterns, ctx); // Neural network populateLoweringONNXConvOpToStablehloPattern(patterns, ctx); populateLoweringONNXConvTransposeOpToStablehloPattern(patterns, ctx); @@ -126,9 +127,6 @@ void FrontendToStablehloLoweringPass::runOnOperation() { populateONNXToStablehloConversionPattern( patterns, &getContext(), enableUnroll); - // add illegal op - target.addIllegalOp(); - // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. diff --git a/src/Conversion/ONNXToStablehlo/DialectBuilder.hpp b/src/Conversion/ONNXToStablehlo/DialectBuilder.hpp index 65e8d55675..6c16e53af9 100644 --- a/src/Conversion/ONNXToStablehlo/DialectBuilder.hpp +++ b/src/Conversion/ONNXToStablehlo/DialectBuilder.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIALECT_BUILDER_STABLEHLO_H +#define ONNX_MLIR_DIALECT_BUILDER_STABLEHLO_H #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -158,3 +159,4 @@ struct MultiDialectBuilder }; } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp new file mode 100644 index 0000000000..b7f2214c02 --- /dev/null +++ b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp @@ -0,0 +1,186 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Softmax.cpp - Softmax Ops -------------------===// +// +// Copyright 2022-2024 +// +// ============================================================================= +// +// This file lowers ONNX softmax operators to Stablehlo dialect. 
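The pattern below decomposes softmax along the requested axis into an element-wise exponential, an additive reduction over that axis, a broadcast of the sums back to the input shape, and an element-wise division. A tiny standalone scalar model of that decomposition for a hypothetical 2x3 float32 input with axis = 1 (plain C++; the comments name the corresponding Stablehlo steps, not the actual op-creation calls):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float x[2][3] = {{1.f, 2.f, 3.f}, {0.f, 0.f, 0.f}};
  float e[2][3], sum[2] = {0.f, 0.f}, y[2][3];
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      e[i][j] = std::exp(x[i][j]);  // stablehlo.exponential
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      sum[i] += e[i][j];            // stablehlo.reduce with an add body
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      y[i][j] = e[i][j] / sum[i];   // broadcast_in_dim + divide
  std::printf("%.3f %.3f %.3f\n", y[0][0], y[0][1], y[0][2]); // 0.090 0.245 0.665
  return 0;
}
```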
+// + +#include "src/Conversion/ONNXToStablehlo/DialectBuilder.hpp" +#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" +#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/TypeUtilities.hpp" +#include "stablehlo/dialect/BroadcastUtils.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +Value getReductionShapeValue(Location loc, PatternRewriter &rewriter, + Value operand, llvm::SmallVector axes, bool keepDims) { + int64_t rank = mlir::cast(operand.getType()).getRank(); + + Value inputShape = rewriter.create(loc, operand); + SmallVector dims; + for (int64_t i = 0; i < rank; i++) { + if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) { + Value dim = rewriter.create(loc, inputShape, i); + dims.push_back(dim); + } else if (keepDims) { + Value dim = rewriter.create(loc, 1); + dims.push_back(dim); + } + } + Value reduceShapeValue = rewriter.create(loc, dims); + reduceShapeValue = rewriter.create(loc, + RankedTensorType::get({rank}, rewriter.getIndexType()), reduceShapeValue); + return reduceShapeValue; +} + +// Calutes Broadcast dimensions +SmallVector getBroadcastDims( + Value operand, llvm::SmallVector axes) { + int64_t rank = mlir::cast(operand.getType()).getRank(); + SmallVector dims; + for (int64_t i = 0; i < rank; i++) { + if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) { + dims.push_back(i); + } + } + + return dims; +} + +Value computeReduceSum(Location loc, Value operand, Value identity, + SmallVector &reduceShape, llvm::SmallVector axes, + PatternRewriter &rewriter, bool keepDims, ShapedType outputType) { + + RankedTensorType operandType = + mlir::cast(operand.getType()); + Type reduceResultType = + RankedTensorType::get(reduceShape, operandType.getElementType()); + stablehlo::ReduceOp reduce = rewriter.create(loc, + reduceResultType, operand, identity, rewriter.getDenseI64ArrayAttr(axes)); + + Region ®ion = reduce.getBody(); + Block &block = region.emplaceBlock(); + RankedTensorType blockArgumentType = + RankedTensorType::get({}, operandType.getElementType()); + block.addArgument(blockArgumentType, loc); + block.addArgument(blockArgumentType, loc); + + BlockArgument firstArgument = *block.args_begin(); + BlockArgument secondArgument = *block.args_rbegin(); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value reduceResult = + rewriter.create(loc, firstArgument, secondArgument); + rewriter.create(loc, reduceResult); + } + Value result = reduce.getResult(0); + + if (keepDims) { + Value reduceShapeValue = + getReductionShapeValue(loc, rewriter, operand, axes, true); + result = rewriter.create( + loc, outputType, result, reduceShapeValue); + } + return result; +} + +SmallVector getReductionShape(ShapedType inputType, + const llvm::SmallVector &axes, bool isKeepdims) { + SmallVector reduceShape; + llvm::ArrayRef inputShape = inputType.getShape(); + int64_t rank = inputType.getRank(); + + // Mark reduction axes. 
+ for (int64_t i = 0; i < rank; ++i) { + if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) + reduceShape.push_back(inputShape[i]); + else if (isKeepdims) + reduceShape.push_back(1); + } + + return reduceShape; +} + +struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern { + ONNXSoftmaxOpLoweringToStablehlo(MLIRContext *ctx) + : ConversionPattern(ONNXSoftmaxOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + Value operand = operands[0]; + assert( + hasStaticShape(operand.getType()) && "Only Static shapes are accepted"); + + Location loc = op->getLoc(); + Type outputType = *op->result_type_begin(); + assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); + assert(mlir::cast(operand.getType()) + .getElementType() + .isF32() && + "Currently Only float32 is supported for input"); + + // Exponential operation + Value ElementwiseExpStableHLO = rewriter.create( + loc, op->getResultTypes(), op->getOperands()); + + if (ElementwiseExpStableHLO == nullptr) + return failure(); + + RankedTensorType ExpOutputType = + mlir::cast(ElementwiseExpStableHLO.getType()); + + // Converting negative indices to Postive indices + int64_t axis = mlir::cast(*op).getAxis(); + if (axis < 0) + axis = ExpOutputType.getRank() + axis; + + SmallVector axes = {axis}; + // Sum of the all the exponents for the denominator + SmallVector reducedShape = + getReductionShape(ExpOutputType, axes, false); + ShapedType ReducedShapeType = mlir::cast( + RankedTensorType::get(reducedShape, ExpOutputType.getElementType())); + Value identity = rewriter.create( + loc, rewriter.getZeroAttr(ExpOutputType.getElementType())); + Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity, + reducedShape, axes, rewriter, false, ReducedShapeType); + if (ReduceSum == nullptr) + return failure(); + + SmallVector broadcast_dims = + getBroadcastDims(ElementwiseExpStableHLO, axes); + Value BroadCastOp = + rewriter.create(loc, ExpOutputType, + ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims)); + if (BroadCastOp == nullptr) + return failure(); + + Value Softmax_output = rewriter.create( + loc, ElementwiseExpStableHLO, BroadCastOp); + if (Softmax_output == nullptr) + return failure(); + + rewriter.replaceOp(op, Softmax_output); + return success(); + } +}; +} // namespace + +void populateLoweringONNXSoftmaxOpToStablehloPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp index c2b99f8e02..265a52b4d7 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_TO_STABLEHLO_H +#define ONNX_MLIR_ONNX_TO_STABLEHLO_H #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -215,4 +216,7 @@ void populateLoweringONNXTransposeOpToStablehloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXUnsqueezeOpToStablehloPattern( RewritePatternSet &, MLIRContext *); +void populateLoweringONNXSoftmaxOpToStablehloPattern( + RewritePatternSet &patterns, MLIRContext *ctx); } // namespace onnx_mlir +#endif diff --git 
a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.hpp b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.hpp index c065c77218..4de86a67ea 100644 --- a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RNN_BASE_STABLEHLO_H +#define ONNX_MLIR_RNN_BASE_STABLEHLO_H #include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp" #include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" @@ -177,3 +178,4 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { } // namespace stablehlo } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp index 3087dc5e28..1050d97053 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIALECT_BUILDER_TOSA_H +#define ONNX_MLIR_DIALECT_BUILDER_TOSA_H #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Builders.h" @@ -115,3 +116,4 @@ struct MultiDialectBuilder }; } // namespace onnx_mlir +#endif diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index 88f47497a7..d5ef2d5053 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -4,7 +4,7 @@ //====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// // -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// Copyright 2020-2024 The TensorFlow Authors. All Rights Reserved. // Copyright (c) 2022-2023 Advanced Micro Devices, Inc. // // ============================================================================= @@ -14,7 +14,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_TO_TOSA_H +#define ONNX_MLIR_ONNX_TO_TOSA_H #include "DialectBuilder.hpp" #include "ONNXToTOSALegalizeUtils.hpp" @@ -124,3 +125,4 @@ void populateLoweringONNXReshapeOpToTOSAPattern(mlir::ConversionTarget &, void populateLoweringONNXResizeOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 99198a0182..59a624ad57 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -4,7 +4,7 @@ //====-------------- DialectBuilder.cpp - Krnl Dialect Builder ------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -14,6 +14,7 @@ #include "llvm/ADT/TypeSwitch.h" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -56,6 +57,16 @@ static StringRef getFormat(const Type &inputType) { //====---------------- Support for Krnl Builder ----------------------===// Value KrnlBuilder::load(Value memref, ValueRange indices) const { + if (indices.size() == 0) { + // case memref<1xdtype> + MemRefType type = dyn_cast_or_null(memref.getType()); + assert(type && "Not MemRefType"); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + MultiDialectBuilder create(*this); + Value iZero = create.math.constantIndex(0); + return b().create(loc(), memref, ValueRange({iZero})); + } + } return b().create(loc(), memref, indices); } @@ -68,12 +79,33 @@ mlir::Value KrnlBuilder::load(mlir::Value memref, mlir::ValueRange indices, } Value KrnlBuilder::loadIE(Value memref, ArrayRef indices) const { + if (indices.size() == 0) { + // case memref<1xdtype> + MemRefType type = dyn_cast_or_null(memref.getType()); + assert(type && "Not MemRefType"); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + MultiDialectBuilder create(*this); + Value iZero = create.math.constantIndex(0); + return b().create(loc(), memref, ValueRange({iZero})); + } + } SmallVector indexValues; IndexExpr::getValues(indices, indexValues); return b().create(loc(), memref, indexValues); } void KrnlBuilder::store(Value val, Value memref, ValueRange indices) const { + if (indices.size() == 0) { + // case memref<1xdtype> + MemRefType type = dyn_cast_or_null(memref.getType()); + assert(type && "Not MemRefType"); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + MultiDialectBuilder create(*this); + Value iZero = create.math.constantIndex(0); + b().create(loc(), val, memref, ValueRange({iZero})); + return; + } + } b().create(loc(), val, memref, indices); } @@ -87,6 +119,17 @@ void KrnlBuilder::store(mlir::Value val, mlir::Value memref, void KrnlBuilder::storeIE( Value val, Value memref, ArrayRef indices) const { + if (indices.size() == 0) { + // case memref<1xdtype> + MemRefType type = dyn_cast_or_null(memref.getType()); + assert(type && "Not MemRefType"); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + MultiDialectBuilder create(*this); + Value iZero = create.math.constantIndex(0); + b().create(loc(), val, memref, ValueRange({iZero})); + return; + } + } SmallVector indexValues; IndexExpr::getValues(indices, indexValues); b().create(loc(), val, memref, indexValues); @@ -212,7 +255,7 @@ void KrnlBuilder::iterateIE(ValueRange originalLoops, ValueRange optimizedLoops, iterateIE(originalLoops, optimizedLoops, lbs, ubs, {}, bodyBuilderFnWrapper); } -mlir::KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops, +KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops, ValueRange optimizedLoops, ArrayRef lbs, ArrayRef ubs, mlir::ValueRange inits, function_ref inputVals, + SmallVectorImpl &resVals, int64_t VL) { + MultiDialectBuilder create(kb); + Value aVal = inputVals[0]; // simd or scalar + Value bVal = inputVals[1]; // simd or scalar + Value cVal = create.krnl.load(C, {}); // scalar always + Value newVal = create.math.add(aVal, bVal); // simd or scalar + newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is + // splatted + res.emplace_back(newVal); // Save simd or scalar result. 
+ } + + The krnl.simdIterateIE will be in charge of loading and saving the values in + memory. The create.math functions have been extended so that when a SIMD + value is computed with a scalar, that scalar will be automaticaly splatted + (aka promoted to a vector of identical values). As a result, the kernel can + be written in a SIMD agnostic value. However, in rare situations, we may + want to know if we are in SIMD mode or not. VL will give the totVL used here + (either totVL>1 or 1). +*/ + +// Determine if an access has one element from the innermost dimensions up to +// innerDim. +bool static hasOneElementInInnermostDims(Value value, int64_t innerDim) { + // Get info. + ShapedType type = mlir::dyn_cast(value.getType()); + assert(type && "expected shaped type"); + int64_t rank = type.getRank(); + mlir::ArrayRef shape = type.getShape(); + for (int64_t i = std::max((int64_t)0, rank - innerDim); i < rank; ++i) + if (shape[i] != 1) + return false; + return true; +} + +void KrnlBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, bool useParallel, ArrayRef inputs, + ArrayRef inputAFs, ArrayRef outputs, + ArrayRef outputAFs, + function_ref inputVals, + llvm::SmallVectorImpl &resultVals, int64_t VL)> + bodyBuilderFn) { + int64_t inputNum = inputs.size(); + assert(inputAFs.size() == inputs.size() && "expected same size"); + int64_t outputNum = outputs.size(); + assert(outputAFs.size() == outputs.size() && "expected same size"); + MultiDialectBuilder create(*this); + + if (VL > 1) { + // Want SIMD, execute full SIMD loops blocked by VL. + ValueRange loopDef = defineLoops(1); + ValueRange blockedLoopDef = block(loopDef[0], VL); + if (useParallel) + parallel({blockedLoopDef[0]}); + + // If we are not guaranteed that every iterations are SIMD iterations, then + // we need to reduce the trip count by a bit so as to not over compute. + // If we are not guaranteed that every iterations are SIMD iterations, then + IndexExpr simdUb = ub; + if (!fullySimd) + simdUb = simdUb - (VL - 1); + iterateIE(loopDef, {blockedLoopDef[0]}, {lb}, {simdUb}, + [&](KrnlBuilder &ck, ValueRange loopInd) { + IndexExprScope scope(ck); + MultiDialectBuilder create(ck); + IndexExpr ind = DimIE(loopInd[0]); + // Load all the inputs as vectors of VL values, with a few exceptions. + // One is if the value is a "none value", leave as is. Another one is + // if the innermost dim is a scalar (ie dim[rank-1] == 1), then we + // just load the scalar. + llvm::SmallVector vecInputVals; + for (int64_t i = 0; i < inputNum; ++i) { + Value input = inputs[i]; + if (isNoneValue(input)) { + // Simply enqueue the none value. + vecInputVals.emplace_back(input); + continue; + } + MemRefType type = mlir::cast(input.getType()); + int64_t rank = type.getRank(); + DimsExpr AF = SymListIE(inputAFs[i]); + assert(rank == (int64_t)AF.size() && "AF expected input rank refs"); + if (hasOneElementInInnermostDims(input, 1)) { + // Has a reference with a scalar innermost dim, just load as a + // scalar. No need to add the induction variable. + Value scalarVal = create.krnl.loadIE(input, AF); + vecInputVals.emplace_back(scalarVal); + } else { + // Have a vector. + VectorType vecType = VectorType::get({VL}, type.getElementType()); + AF[rank - 1] = AF[rank - 1] + ind; // Add induction var. + Value vecVal = create.vec.loadIE(vecType, input, AF, {}); + vecInputVals.emplace_back(vecVal); + } + } + // Call the method to compute the values. 
+ llvm::SmallVector vecResVals; + bodyBuilderFn(create.krnl, vecInputVals, vecResVals, VL); + assert((int64_t)vecResVals.size() == outputNum && + "loop body with incorrect number of results"); + // Store all the outputs as vectors of VL values. + for (int64_t i = 0; i < outputNum; ++i) { + MemRefType type = mlir::cast(outputs[i].getType()); + DimsExpr AF = SymListIE(outputAFs[i]); + int64_t rank = type.getRank(); + assert(rank == (int64_t)AF.size() && "AF expected output rank refs"); + AF[rank - 1] = AF[rank - 1] + ind; + create.vec.storeIE(vecResVals[i], outputs[i], AF, {}); + } + }); + if (fullySimd) + // We asserted that we only have SIMD iterations, so we are done. + return; + // Account for the loop iterations performed above. + IndexExpr tripCount = ub - lb; + IndexExpr missingIters = tripCount % VL; + IndexExpr completedIters = tripCount - missingIters; + if (missingIters.isLiteralAndIdenticalTo(0)) { + // Detected that we only have SIMD iterations; we are also done. + return; + } + // We may have additional iterations to perform; adjust lb to skip the + // completed iterations. + lb = lb + completedIters; + } + // Handle remaining scalar values (from lb to ub without unrolling). + ValueRange loopDef = defineLoops(1); + iterateIE( + loopDef, loopDef, {lb}, {ub}, [&](KrnlBuilder &ck, ValueRange loopInd) { + IndexExprScope scope(ck); + MultiDialectBuilder create(ck); + IndexExpr ind = DimIE(loopInd[0]); + // Load all the inputs as scalar values. + llvm::SmallVector scalarInputVals; + for (int64_t i = 0; i < inputNum; ++i) { + Value input = inputs[i]; + if (isNoneValue(input)) { + // Simply enqueue the none value. + scalarInputVals.emplace_back(input); + continue; + } + MemRefType type = mlir::cast(input.getType()); + int64_t rank = type.getRank(); + DimsExpr AF = SymListIE(inputAFs[i]); + if (hasOneElementInInnermostDims(input, 1)) { + // Has a reference with a scalar innermost dim, just load as a + // scalar. No need to add the induction variable. + Value scalarVal = create.krnl.loadIE(input, AF); + scalarInputVals.emplace_back(scalarVal); + } else { + AF[rank - 1] = AF[rank - 1] + ind; + Value scalarVal = create.krnl.loadIE(input, AF); + scalarInputVals.emplace_back(scalarVal); + } + } + // Call the method to compute the values. + llvm::SmallVector scalarResVals; + bodyBuilderFn(create.krnl, scalarInputVals, scalarResVals, /*VL*/ 1); + assert((int64_t)scalarResVals.size() == outputNum && + "loop body with incorrect number of results"); + // Store all the outputs as scalar values. + for (int64_t i = 0; i < outputNum; ++i) { + MemRefType type = mlir::cast(outputs[i].getType()); + DimsExpr AF = SymListIE(outputAFs[i]); + int64_t rank = type.getRank(); + assert(rank == (int64_t)AF.size() && "AF expected output rank refs"); + AF[rank - 1] = AF[rank - 1] + ind; + create.krnl.storeIE(scalarResVals[i], outputs[i], AF); + } + }); +} + void KrnlBuilder::yield(mlir::ValueRange iterArgs) const { b().create(loc(), iterArgs); } diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index 1a50768542..85c40f9b9d 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -4,7 +4,7 @@ //====--------- DialectBuilder.hpp - Krnl Dialect Builder -----------------===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors.
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIALECT_BUILDER_KRNL_H +#define ONNX_MLIR_DIALECT_BUILDER_KRNL_H #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" @@ -69,7 +70,9 @@ struct KrnlBuilder : public DialectBuilder { mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const; void parallel(mlir::ValueRange loops) const; - // Lambda passes loop indices as 2nd parameter. + // Iterate over optimized loops given the original loops, lbs and ubs. The + // lambda function implements the body of the loop, and receives a KRNL + // builder and the loop indices. void iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ValueRange lbs, mlir::ValueRange ubs, mlir::function_ref lbs, mlir::ArrayRef ubs, @@ -98,6 +101,41 @@ struct KrnlBuilder : public DialectBuilder { mlir::ValueRange blockIters)> bodyBuilderFn) const; + // Iterate over a loop executing the loop body in SIMD mode (of vector length + // VL) from lb to ub. A scalar loop may execute up to VL-1 loop + // iterations when the trip count is not a multiple of VL. If fullySimd is + // true, then the call assumes that the trip count is a multiple of VL. + // + // This call needs to be given each of the memref inputs to the loop body, + // given as an ordered pair of a memref value and its corresponding access + // function. The same holds for all the memref outputs of the loop body. + // + // The loop body is given a KRNL builder and a list of loaded input values (in + // the same order as the input memrefs and access functions). It must generate + // values that are placed in the result list in the same order as the output + // memrefs and access functions. + // + // It is the responsibility of this call to load each of the inputs and + // store each of the outputs. When operating in SIMD mode, every input and + // output value is a vector of length VL. In scalar mode, they are simply + // scalar values. + // + // SIMD is exploited in the innermost dimension of each access function. + // This call is only applicable to loop bodies where every input/output is + // strided in its innermost dimension. Inputs can also be loop invariant + // (scalar), in terms of the loop being iterated on. + // + // If useParallel is true, then the blocked SIMD loop is executed in parallel. + + void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + bool useParallel, mlir::ArrayRef inputs, + mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, + mlir::ArrayRef outputAFs, + mlir::function_ref inputVals, + llvm::SmallVectorImpl &resultVals, int64_t VL)> + bodyBuilderFn); + void yield(mlir::ValueRange iterArgs) const; void copyToBuffer( @@ -130,8 +168,8 @@ struct KrnlBuilder : public DialectBuilder { mlir::ValueRange bStart, mlir::Value C, mlir::ValueRange cStart, // Loops are the krnl loop indices that this matmul replaces mlir::ValueRange loops, - // the computeStarts indicate the i/j/k indices pointing to the beginning - // of the matmul computation. + // the computeStarts indicate the i/j/k indices pointing to the + // beginning of the matmul computation. mlir::ValueRange computeStarts, // The globalUBs are the global bounds on the original I, J, K // dimensions.
@@ -193,8 +231,9 @@ struct KrnlBuilder : public DialectBuilder { // We use here a Affine builder that generates Krnl Load and Store ops instead // of the affine memory ops directly. This is because we can still generate -// Krnl Ops while lowering the dialect, and the big advantage of the Krnl memory -// operations is that they distinguish themselves if they are affine or not. +// Krnl Ops while lowering the dialect, and the big advantage of the Krnl +// memory operations is that they distinguish themselves if they are affine or +// not. using AffineBuilderKrnlMem = GenericAffineBuilder; @@ -254,3 +293,4 @@ struct MultiDialectBuilder }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Krnl/KrnlHelper.hpp b/src/Dialect/Krnl/KrnlHelper.hpp index 600c920689..07cfb05bc6 100644 --- a/src/Dialect/Krnl/KrnlHelper.hpp +++ b/src/Dialect/Krnl/KrnlHelper.hpp @@ -4,7 +4,7 @@ //====---------------- KrnlHelper.hpp - Krnl Dialect Helper----------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_HELPER_H +#define ONNX_MLIR_KRNL_HELPER_H #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -100,3 +101,4 @@ bool isKrnlGlobalConstant(mlir::Value result); } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index 2edb02154d..fbe8c34765 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -839,13 +839,9 @@ uint64_t KrnlGlobalOp::getBufferSize() { return affine::getIntOrFloatMemRefSizeInBytes(memRefTy).value(); } -void KrnlGlobalOp::setBuffer(ArrayRef rawData) { - return; -} - -void KrnlGlobalOp::freeBuffer(ArrayRef rawData) { - return; -} +void KrnlGlobalOp::setBuffer(ArrayRef rawData) { return; } + +void KrnlGlobalOp::freeBuffer(ArrayRef rawData) { return; } //===----------------------------------------------------------------------===// // KrnlMatMulOp diff --git a/src/Dialect/Krnl/KrnlOps.hpp b/src/Dialect/Krnl/KrnlOps.hpp index 30d6b16d6d..854d0fbf18 100644 --- a/src/Dialect/Krnl/KrnlOps.hpp +++ b/src/Dialect/Krnl/KrnlOps.hpp @@ -4,7 +4,7 @@ //===--------------------- KrnlOps.hpp - Krnl Operations ------------------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_H +#define ONNX_MLIR_KRNL_H #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" @@ -27,3 +28,4 @@ #define GET_OP_CLASSES #include "src/Dialect/Krnl/KrnlOps.hpp.inc" +#endif diff --git a/src/Dialect/Krnl/KrnlTypes.hpp b/src/Dialect/Krnl/KrnlTypes.hpp index fa9a8e84d7..2766dbf357 100644 --- a/src/Dialect/Krnl/KrnlTypes.hpp +++ b/src/Dialect/Krnl/KrnlTypes.hpp @@ -4,7 +4,7 @@ //===------------------- KrnlTypes.hpp - Krnl Operations ------------------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_TYPES_H +#define ONNX_MLIR_KRNL_TYPES_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -79,3 +80,4 @@ void customizeTypeConverter(mlir::LLVMTypeConverter &typeConverter); } // namespace krnl } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 601d9f8b01..d04c2a8b83 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -51,74 +51,161 @@ namespace onnx_mlir { // ONNX Integers as MLIR signless, and only flag the ONNX Unsigned Integer as // MLIR unsigned integer. +/* static */ bool MathBuilder::isVector(Value val) { + return isVector(val.getType()); +} + /* static */ bool MathBuilder::isVector(Type type) { return mlir::dyn_cast(type) != nullptr; } -/* static */ Type MathBuilder::elementTypeWithVector(Type elementOrVectorType) { +/* static */ Type MathBuilder::elementTypeOfScalarOrVector(Value val) { + return elementTypeOfScalarOrVector(val.getType()); +} + +/* static */ Type MathBuilder::elementTypeOfScalarOrVector( + Type elementOrVectorType) { VectorType vectorType = mlir::dyn_cast(elementOrVectorType); if (vectorType) return vectorType.getElementType(); return elementOrVectorType; } +// return a vector of "elementType" with the same vector shape as "vectorType" /* static */ Type MathBuilder::getTypeWithVector( - VectorType vectorType, Type elementType) { - if (vectorType) - return VectorType::get(vectorType.getShape(), elementType); - return elementType; + Type vectorType, Type elementType) { + assert(!isVector(elementType) && "element type expected to be a scalar"); + // When vectorType is not a vector, then we need to return a scalar of the + // type elementType. + if (!isVector(vectorType)) + return elementType; + // When vectorType is actually a vector, then replicate the shape of + // vectorType with the element type of elementType. 
+ return VectorType::get( + mlir::cast(vectorType).getShape(), elementType); } -/* static */ bool MathBuilder::isIntegerWithVector(Type elementOrVectorType) { - Type elementType = elementTypeWithVector(elementOrVectorType); +/* static */ bool MathBuilder::isScalarOrVectorInteger(Value val) { + return isScalarOrVectorInteger(val.getType()); +} + +/* static */ bool MathBuilder::isScalarOrVectorInteger( + Type elementOrVectorType) { + Type elementType = elementTypeOfScalarOrVector(elementOrVectorType); return mlir::isa(elementType) || mlir::isa(elementType); } -/* static */ bool MathBuilder::isUnsignedIntegerWithVector( +/* static */ bool MathBuilder::isScalarOrVectorUnsignedInteger(Value val) { + return isScalarOrVectorUnsignedInteger(val.getType()); +} + +/* static */ bool MathBuilder::isScalarOrVectorUnsignedInteger( Type elementOrVectorType) { - Type elementType = elementTypeWithVector(elementOrVectorType); + Type elementType = elementTypeOfScalarOrVector(elementOrVectorType); return elementType.isUnsignedInteger(); } -/* static */ bool MathBuilder::isFloatWithVector(Type elementOrVectorType) { - Type elementType = elementTypeWithVector(elementOrVectorType); +/* static */ bool MathBuilder::isScalarOrVectorFloat(Value val) { + return isScalarOrVectorFloat(val.getType()); +} + +/* static */ bool MathBuilder::isScalarOrVectorFloat(Type elementOrVectorType) { + Type elementType = elementTypeOfScalarOrVector(elementOrVectorType); return mlir::isa(elementType); } +bool MathBuilder::splatToMatch(Value &first, Value &second) const { + Type firstType = first.getType(); + Type secondType = second.getType(); + VectorType firstVectorType = mlir::dyn_cast(firstType); + VectorType secondVectorType = mlir::dyn_cast(secondType); + MultiDialectBuilder create(*this); + LLVM_DEBUG(llvm::dbgs() << "Splat to match first: " << firstType << "\n"; + llvm::dbgs() << " second: " << secondType << "\n";); + + // Splat first if needed. + if (!firstVectorType && secondVectorType) { + firstVectorType = VectorType::get(secondVectorType.getShape(), firstType); + first = create.vec.splat(firstVectorType, first); + LLVM_DEBUG(llvm::dbgs() << " splat first\n"); + return true; + } + // Splat second if needed. + if (firstVectorType && !secondVectorType) { + secondVectorType = VectorType::get(firstVectorType.getShape(), secondType); + second = create.vec.splat(secondVectorType, second); + LLVM_DEBUG(llvm::dbgs() << " splat second\n"); + return true; + } + // Otherwise check compatibility. + assert(create.vec.compatibleShapes(firstType, secondType) && + "expected compatible shapes"); + return false; +} + +bool MathBuilder::splatToMatch( + Value &first, Value &second, Value &third) const { + bool changeIn12 = splatToMatch(first, second); + bool changeIn13 = splatToMatch(first, third); + if (!changeIn12 && changeIn13) + // Have missed changes in 1-2 pair, redo. + splatToMatch(first, second); + return changeIn12 || changeIn13; +} + +void MathBuilder::splatToMatch(llvm::SmallVectorImpl &vals) const { + // Do not check the types when matching splats as this interface is called + // blindly on a list of vals. + int64_t size = vals.size(); + if (size <= 1) + return; // Nothing to do with 0 or 1 values. 
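+ // Only the pairwise (2-value) and triple (3-value) cases below are supported; longer lists fall through to llvm_unreachable.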
+ if (size == 2) { + splatToMatch(vals[0], vals[1]); + } else if (size == 3) { + splatToMatch(vals[0], vals[1], vals[2]); + } else { + llvm_unreachable("can only splat to match up to 3 values"); + } +} + Value MathBuilder::abs(Value val) const { - if (isIntegerWithVector(val.getType())) + if (isScalarOrVectorInteger(val)) return b().create(loc(), val); - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected int or float"); } Value MathBuilder::andi(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int"); } Value MathBuilder::ori(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int"); } Value MathBuilder::xori(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int"); } Value MathBuilder::add(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) { - Type elemType = elementTypeWithVector(lhs.getType()); + if (isScalarOrVectorInteger(lhs)) { + Type elemType = elementTypeOfScalarOrVector(lhs); if (elemType.isUnsignedInteger()) { unsigned elemWidth = mlir::cast(elemType).getWidth(); Value castLhs = castToSignless(lhs, elemWidth); @@ -129,24 +216,26 @@ Value MathBuilder::add(Value lhs, Value rhs) const { } else return b().create(loc(), lhs, rhs); } - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::sub(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::mul(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) { - Type elemType = elementTypeWithVector(lhs.getType()); + if (isScalarOrVectorInteger(lhs)) { + Type elemType = elementTypeOfScalarOrVector(lhs); if (elemType.isUnsignedInteger()) { unsigned elemWidth = mlir::cast(elemType).getWidth(); Value castLhs = castToSignless(lhs, elemWidth); @@ -157,66 +246,125 @@ Value MathBuilder::mul(Value lhs, Value rhs) const { } else return b().create(loc(), lhs, rhs); } - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::div(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); - if 
(isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::rem(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } +Value MathBuilder::round(Value x) const { + Type type = x.getType(); + assert(isScalarOrVectorFloat(type) && "expected float"); + // Use algorithm originally posted in ONNXtoKRNL/Math/Elementwise.cpp + // lowering. + + // Use numpy algorithm for rint as follows. + // ``` + // double y, r; + // y = npy_floor(x); + // r = x - y; + // + // if (r > 0.5) { + // y += 1.0; + // } + // + // /* Round to nearest even */ + // if (r == 0.5) { + // r = y - 2.0*npy_floor(0.5*y); + // if (r == 1.0) { + // y += 1.0; + // } + // } + // return y; + // ``` + Value one = constant(type, 1.0); + Value two = constant(type, 2.0); + Value half = constant(type, 0.5); + Value y = floor(x); + Value r = sub(x, y); + // r > 0.5 + Value rGreaterThanHalf = sgt(r, half); + Value y1 = select(rGreaterThanHalf, add(y, one), y); + // r == 0.5: round to nearest even. + Value y2 = mul(half, y); + y2 = floor(y2); + y2 = mul(y2, two); + Value rr = sub(y, y2); + Value rrEqualOne = eq(rr, one); + y2 = select(rrEqualOne, add(y, one), y); + + Value rEqualHalf = eq(r, half); + return select(rEqualHalf, y2, y1); +} + Value MathBuilder::copySign(mlir::Value rem, mlir::Value dividend) const { + splatToMatch(rem, dividend); assert(rem.getType() == dividend.getType() && "expected same type"); - if (isFloatWithVector(rem.getType())) + if (isScalarOrVectorFloat(rem)) return b().create(loc(), rem, dividend); llvm_unreachable("expected float"); } Value MathBuilder::ceilDiv(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int"); } +Value MathBuilder::clip(Value val, Value lb, Value ub) const { + // Don't perform type assert and/or splats as it will be done in the min/max + // operations. + val = max(val, lb); // Clip lower range. + return min(val, ub); // Clip upper range. +} + Value MathBuilder::floorDiv(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) // Using regular unsigned div is ok as it rounds toward zero. 
return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int"); } // return (lhs * rhs) + acc Value MathBuilder::fma(Value lhs, Value rhs, Value acc) const { + splatToMatch(lhs, rhs, acc); assert((lhs.getType() == rhs.getType()) && (rhs.getType() == acc.getType()) && "expected same type"); - if (isFloatWithVector(lhs.getType()) && !isa(lhs.getType())) + if (isScalarOrVectorFloat(lhs) && isVector(lhs)) { return b().create(loc(), lhs, rhs, acc); - return add(mul(lhs, rhs), acc); + } + return add(mul(lhs, rhs), acc); // Handle broadcast there. } Value MathBuilder::erf(Value val) const { @@ -224,197 +372,215 @@ Value MathBuilder::erf(Value val) const { } Value MathBuilder::exp(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::exp2(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::log(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::log2(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::sqrt(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::pow(Value base, Value exp) const { - if (isFloatWithVector(base.getType())) + splatToMatch(base, exp); + if (isScalarOrVectorFloat(base)) return b().create(loc(), base, exp); llvm_unreachable("expected base float"); } Value MathBuilder::neg(Value val) const { - if (isIntegerWithVector(val.getType())) + if (isScalarOrVectorInteger(val)) // Returns 0 - val. 
return sub(constant(val.getType(), 0), val); - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected int or float"); } Value MathBuilder::ceil(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::floor(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::tanh(Value val) const { - if (isFloatWithVector(val.getType())) + if (isScalarOrVectorFloat(val)) return b().create(loc(), val); llvm_unreachable("expected float"); } Value MathBuilder::min(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::max(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return b().create(loc(), lhs, rhs); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return b().create(loc(), lhs, rhs); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return b().create(loc(), lhs, rhs); llvm_unreachable("expected int or float"); } Value MathBuilder::sgt(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::sgt); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::OGT); llvm_unreachable("expected int or float"); } Value MathBuilder::sge(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::sge); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::OGE); llvm_unreachable("expected int or float"); } Value MathBuilder::slt(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::slt); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::OLT); llvm_unreachable("expected int or float"); } Value MathBuilder::sle(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::sle); - if (isFloatWithVector(lhs.getType())) + if 
(isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::OLE); llvm_unreachable("expected int or float"); } Value MathBuilder::ugt(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::ugt); llvm_unreachable("expected unsigned int"); } Value MathBuilder::uge(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::uge); llvm_unreachable("expected unsigned int"); } Value MathBuilder::ult(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::ult); llvm_unreachable("expected unsigned int"); } Value MathBuilder::ule(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isUnsignedIntegerWithVector(lhs.getType())) + if (isScalarOrVectorUnsignedInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::ule); llvm_unreachable("expected unsigned int"); } Value MathBuilder::gt(Value lhs, Value rhs) const { - if (isUnsignedIntegerWithVector(lhs.getType())) + splatToMatch(lhs, rhs); + if (isScalarOrVectorUnsignedInteger(lhs)) return ugt(lhs, rhs); return sgt(lhs, rhs); } Value MathBuilder::ge(Value lhs, Value rhs) const { - if (isUnsignedIntegerWithVector(lhs.getType())) + splatToMatch(lhs, rhs); + if (isScalarOrVectorUnsignedInteger(lhs)) return uge(lhs, rhs); return sge(lhs, rhs); } Value MathBuilder::lt(Value lhs, Value rhs) const { - if (isUnsignedIntegerWithVector(lhs.getType())) + splatToMatch(lhs, rhs); + if (isScalarOrVectorUnsignedInteger(lhs)) return ult(lhs, rhs); return slt(lhs, rhs); } Value MathBuilder::le(Value lhs, Value rhs) const { - if (isUnsignedIntegerWithVector(lhs.getType())) + splatToMatch(lhs, rhs); + if (isScalarOrVectorUnsignedInteger(lhs)) return ule(lhs, rhs); return sle(lhs, rhs); } Value MathBuilder::eq(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::eq); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::OEQ); llvm_unreachable("expected int or float"); } Value MathBuilder::neq(Value lhs, Value rhs) const { + splatToMatch(lhs, rhs); assert(lhs.getType() == rhs.getType() && "expected same type"); - if (isIntegerWithVector(lhs.getType())) + if (isScalarOrVectorInteger(lhs)) return createArithCmp(lhs, rhs, arith::CmpIPredicate::ne); - if (isFloatWithVector(lhs.getType())) + if (isScalarOrVectorFloat(lhs)) return createArithCmp(lhs, rhs, arith::CmpFPredicate::ONE); llvm_unreachable("expected int or float"); } Value MathBuilder::select(Value cmp, Value trueVal, Value falseVal) const { + splatToMatch(cmp, trueVal, falseVal); assert(trueVal.getType() == falseVal.getType() && "expected same type"); return b().create(loc(), cmp, trueVal, falseVal); } @@ -422,7 +588,7 @@ 
Value MathBuilder::select(Value cmp, Value trueVal, Value falseVal) const { Value MathBuilder::constant(Type type, double val) const { Value constant = nullptr; // Could be a vector type; look at the element type. - Type elementType = elementTypeWithVector(type); + Type elementType = elementTypeOfScalarOrVector(type); TypeSwitch(elementType) .Case([&](Type) { constant = @@ -569,7 +735,7 @@ TypedAttr MathBuilder::positiveInfAttr(mlir::Type type) const { Value MathBuilder::negativeInf(Type type) const { // Strip vector type if any. - Type elementType = elementTypeWithVector(type); + Type elementType = elementTypeOfScalarOrVector(type); TypedAttr attr = negativeInfAttr(elementType); Value constant = b().create(loc(), attr); assert(constant != nullptr && "Expecting valid constant value"); @@ -584,7 +750,7 @@ Value MathBuilder::negativeInf(Type type) const { Value MathBuilder::positiveInf(Type type) const { // Strip vector type if any. - Type elementType = elementTypeWithVector(type); + Type elementType = elementTypeOfScalarOrVector(type); TypedAttr attr = positiveInfAttr(elementType); Value constant = b().create(loc(), attr); assert(constant != nullptr && "Expecting valid constant value"); @@ -601,7 +767,7 @@ Value MathBuilder::createArithCmp( Value lhs, Value rhs, arith::CmpIPredicate pred) const { Type type = lhs.getType(); assert(type == rhs.getType() && "Operands should have the same type"); - assert(isIntegerWithVector(type) && "expected int"); + assert(isScalarOrVectorInteger(type) && "expected int"); return b().create(loc(), pred, lhs, rhs); } @@ -609,7 +775,7 @@ Value MathBuilder::createArithCmp( Value lhs, Value rhs, arith::CmpFPredicate pred) const { Type type = lhs.getType(); assert(type == rhs.getType() && "Operands should have the same type"); - assert(isFloatWithVector(type) && "expected float"); + assert(isScalarOrVectorFloat(type) && "expected float"); return b().create(loc(), pred, lhs, rhs); } @@ -618,11 +784,10 @@ Value MathBuilder::createArithCmp( // best of my understanding. Value MathBuilder::castToSignless(Value val, int64_t width) const { Type valType = val.getType(); - VectorType vecType = mlir::dyn_cast(valType); - Type valElemType = elementTypeWithVector(valType); + Type valElemType = elementTypeOfScalarOrVector(valType); assert(mlir::isa(valElemType) && !valElemType.isSignlessInteger() && "Expecting signed integer type"); - Type destType = getTypeWithVector(vecType, b().getIntegerType(width)); + Type destType = getTypeWithVector(valType, b().getIntegerType(width)); return b() .create(loc(), destType, val) .getResult(0); @@ -630,11 +795,10 @@ Value MathBuilder::castToSignless(Value val, int64_t width) const { Value MathBuilder::castToUnsigned(Value val, int64_t width) const { Type valType = val.getType(); - VectorType vecType = mlir::dyn_cast(valType); - Type valElemType = elementTypeWithVector(valType); + Type valElemType = elementTypeOfScalarOrVector(valType); assert(mlir::isa(valElemType) && "Expecting integer type"); Type destType = - getTypeWithVector(vecType, b().getIntegerType(width, false /*signed*/)); + getTypeWithVector(valType, b().getIntegerType(width, false /*signed*/)); return b() .create(loc(), destType, val) .getResult(0); @@ -642,25 +806,55 @@ Value MathBuilder::castToUnsigned(Value val, int64_t width) const { // Methods inspired from MLIR TosaToLinalg CastOp. Value MathBuilder::cast(Type destType, Value src) const { - // Get element type and vector types (if any, i.e. possibly nullptr). 
Type srcType = src.getType(); - VectorType srcVecType = mlir::dyn_cast(srcType); - VectorType destVecType = mlir::dyn_cast(destType); - Type srcElemType = elementTypeWithVector(srcType); - Type destElemType = elementTypeWithVector(destType); - // Make sure we don't mix vector and scalars. - assert(((srcVecType && destVecType) || (!srcVecType && !destVecType)) && - "expect both to be scalars or vectors"); // Check if we even need a cast. if (srcType == destType) return src; + // Get element type and vector types (if any, i.e. possibly nullptr). + + /////////////////////////////////////////////////////////////////////// + // WARNING: do not confuse (src|dest) ElemType and (src|dest) Type! + // + // ElemTypes and Types are the same for scalar BUT NOT for vector inputs. + // For vectors inputs, Types are vector and ElemTypes + // are the element type associated with the vector. + // + // When testing for properties (is int, float,...): use ElemTypes. + // When creating ops, use Types for types to translate to, as if we have a + // scalar input, we need a scalar output; and if we have a vector input, then + // we need a vector output. + /////////////////////////////////////////////////////////////////////// + + Type srcElemType = elementTypeOfScalarOrVector(srcType); + Type destElemType = elementTypeOfScalarOrVector(destType); + VectorType srcVecType = mlir::dyn_cast(srcType); + VectorType destVecType = mlir::dyn_cast(destType); + assert(VectorBuilder::compatibleShapes(srcType, destType) && + "expected compatible vector shape (if any)"); + + // Handling of special cases for vectors. + if (destVecType && !srcVecType) { + // When the destination type is requested to be a vector type, but the input + // is not, then perform a scalar cast first, and then splat the output. + Value scalarCastVal = cast(destElemType, src); + MultiDialectBuilder create(*this); + return create.vec.splat(destVecType, scalarCastVal); + } + if (srcVecType && !destVecType) { + // When the source (to be cast) is a vector, but the destination type is + // not, then just transform the destination type to a vector of the same + // shape as srcType and the elementType of destType. + destType = getTypeWithVector(srcType, destElemType); + assert(destElemType == elementTypeOfScalarOrVector(destType) && + "correctness check"); + } // Process index types first. if (mlir::isa(srcElemType)) { // If the source is an index type, first convert it into a signless int of // size 64. srcElemType = b().getIntegerType(64); - srcType = getTypeWithVector(srcVecType, srcElemType); + srcType = getTypeWithVector(srcType, srcElemType); src = b().create(loc(), srcType, src); } bool destIsIndex = false; @@ -669,7 +863,7 @@ Value MathBuilder::cast(Type destType, Value src) const { // If the dest is an index type, pretend for now that we want it to be // converted to signless int of size 64. destElemType = b().getIntegerType(64); - destType = getTypeWithVector(destVecType, destElemType); + destType = getTypeWithVector(destType, destElemType); destIsIndex = true; } @@ -712,7 +906,7 @@ Value MathBuilder::cast(Type destType, Value src) const { // An integer constant must be signless. 
unsigned srcElemWidth = mlir::cast(srcElemType).getWidth(); constantType = getTypeWithVector( - srcVecType, IntegerType::get(srcElemType.getContext(), srcElemWidth)); + srcType, IntegerType::get(srcElemType.getContext(), srcElemWidth)); src = castToSignless(src, srcElemWidth); } Value zero = constant(constantType, 0); @@ -733,8 +927,9 @@ Value MathBuilder::cast(Type destType, Value src) const { mlir::isa(destElemType)) { // TosaToLinalg in MLIR uses a fancier algorithm that clamps values to // min/max signed/unsigned integer values. - if (destType.isUnsignedInteger()) { - Type castType = b().getIntegerType(destElemWidth); + if (destElemType.isUnsignedInteger()) { + Type castElementType = b().getIntegerType(destElemWidth); + Type castType = getTypeWithVector(destType, castElementType); Value cast = b().create(loc(), castType, src); return castToUnsigned(cast, destElemWidth); } else { @@ -759,24 +954,26 @@ Value MathBuilder::cast(Type destType, Value src) const { } // Int to int conversion. - if (mlir::isa(srcType) && mlir::isa(destType)) { - if (srcType.isUnsignedInteger()) { + if (mlir::isa(srcElemType) && + mlir::isa(destElemType)) { + if (srcElemType.isUnsignedInteger()) { // Unsigned to unsigned/signed conversion. // Same bit width for unsigned to signed conversion. - if ((srcElemWidth == destElemWidth) && destType.isSignlessInteger()) + if ((srcElemWidth == destElemWidth) && destElemType.isSignlessInteger()) return castToSignless(src, srcElemWidth); // Different bit width. assert((bitExtend || bitTrunc) && "expected extend or trunc"); // Has to convert to signless first, and reconvert output to unsigned. Value cast = castToSignless(src, srcElemWidth); - Type castType = b().getIntegerType(destElemWidth); + Type castElemType = b().getIntegerType(destElemWidth); + Type castType = getTypeWithVector(destType, castElemType); if (bitExtend) { cast = b().create(loc(), castType, cast); } else { // TosaToLinalg use a clipping algo, not sure if needed. cast = b().create(loc(), castType, cast); } - if (destType.isUnsignedInteger()) { + if (destElemType.isUnsignedInteger()) { // Unsigned to unsigned conversion. return castToUnsigned(cast, destElemWidth); } else { @@ -787,7 +984,7 @@ Value MathBuilder::cast(Type destType, Value src) const { // Signed to unsigned/signed conversion. // Handle signed integer // Same bit width for signed to unsigned conversion. - if ((srcElemWidth == destElemWidth) && destType.isUnsignedInteger()) + if ((srcElemWidth == destElemWidth) && destElemType.isUnsignedInteger()) return castToUnsigned(src, srcElemWidth); // Different bit width. Value dest = src; @@ -798,7 +995,7 @@ Value MathBuilder::cast(Type destType, Value src) const { dest = b().create(loc(), destType, src); if (destIsIndex) return b().create(loc(), b().getIndexType(), dest); - if (destType.isUnsignedInteger()) { + if (destElemType.isUnsignedInteger()) { return castToUnsigned(dest, destElemWidth); } else { return dest; @@ -886,6 +1083,7 @@ IntegerAttr MemRefBuilder::computeAlignment(int64_t alignment) const { // Alloc calls need a list of values, only for the dynamic shapes. Extract these // values from the list of index expressions that represent the shape of the // memref. + void MemRefBuilder::computeDynSymbols(MemRefType type, llvm::SmallVectorImpl &dims, llvm::SmallVectorImpl &dynSymbols) const { @@ -981,6 +1179,47 @@ memref::AllocOp MemRefBuilder::alignedAlloc(MemRefType type, //===----------------------------------------------------------------------===// // Info about memory size. 
+// Compute static size of memref in elements. Return true if has +// static size. +/*static*/ bool MemRefBuilder::getStaticMemSize( + MemRefType type, int64_t &staticSize, int64_t range) { + Type elementType = type.getElementType(); + assert(!(mlir::isa(elementType)) && "unsupported vector type"); + ArrayRef shape = type.getShape(); + staticSize = 1; // Multiplication of static sizes. + bool staticShape = true; // Static until proven otherwise. + int64_t rank = type.getRank(); + // Process with range [lb inclusive, ub exclusive) + int64_t lb = 0, ub = rank; + if (range == 0) + // Empty range, nothing to do. + return staticShape; + if (range > 0) { + // Positive range r: interval is [ 0, min(r, rank) ). + ub = (range < rank) ? range : rank; + } else { + // Negative range r: interval is [ max(0, r+rank) to rank ). + range += rank; + lb = range > 0 ? range : 0; + } + assert(lb >= 0 && ub <= rank && "out of bound range"); + for (int64_t i = 0; i < rank; ++i) { + if (shape[i] == ShapedType::kDynamic) { + if (i >= lb && i < ub) { + // Keep track of static shape and dynamic sizes only when inbounds. + staticShape = false; + } + } else { + // Has constant shape. + if (i >= lb && i < ub) { + // Keep track of static size only when inbounds. + staticSize *= shape[i]; + } + } + } + return staticShape; +} + // Compute static and dynamic size of memref in elements. Return true if has // static size. bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, @@ -1041,9 +1280,9 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, // Alloc functions with alignment and padding for SIMD Value MemRefBuilder::alignedAllocWithSimdPadding( - mlir::MemRefType type, int64_t simdUnroll, int64_t alignment) const { + mlir::MemRefType type, int64_t VL, int64_t alignment) const { llvm::SmallVector dynSymbols; - return alignedAllocWithSimdPadding(type, dynSymbols, simdUnroll, alignment); + return alignedAllocWithSimdPadding(type, dynSymbols, VL, alignment); } Value MemRefBuilder::alignedAllocWithSimdPadding(MemRefType type, @@ -1480,21 +1719,47 @@ void SCFBuilder::yield() const { b().create(loc()); } // Vector Builder //===----------------------------------------------------------------------===// -int64_t VectorBuilder::getMachineVectorLength(const Type &elementType) const { - VectorMachineSupport *vms = - VectorMachineSupport::getGlobalVectorMachineSupport(); +/*static*/ bool VectorBuilder::compatibleShapes(const Type t1, const Type t2) { + // If both are vectors, check that the shapes are identical. + VectorType vt1 = mlir::dyn_cast(t1); + VectorType vt2 = mlir::dyn_cast(t2); + if (vt1 && vt2) { + auto shape1 = vt1.getShape(); + auto shape2 = vt2.getShape(); + // Different rank, return false. + if (shape1.size() != shape2.size()) + return false; + for (int64_t i = 0; i < (int64_t)shape1.size(); ++i) + if (shape1[i] != shape2[i]) + return false; + // Same dim and shapes + return true; + } + // Neither is a vector (no shape tests) or only one is a vector (and the other + // one can thus be broadcasted to it), we have compatible shapes. + return true; +} + +/*static*/ bool VectorBuilder::compatibleTypes(const Type t1, const Type t2) { + Type e1 = MathBuilder::elementTypeOfScalarOrVector(t1); + Type e2 = MathBuilder::elementTypeOfScalarOrVector(t2); + return (e1 == e2) && compatibleShapes(t1, t2); +} + +int64_t VectorBuilder::getArchVectorLength(const Type &elementType) const { // Even if unsupported, we can always compute one result per vector. 
- return std::max((int64_t)1, vms->getVectorLength(elementType)); + return std::max( + (int64_t)1, VectorMachineSupport::getArchVectorLength(elementType)); } -int64_t VectorBuilder::getMachineVectorLength(const VectorType &vecType) const { - return getMachineVectorLength(vecType.getElementType()); +int64_t VectorBuilder::getArchVectorLength(const VectorType &vecType) const { + return getArchVectorLength(vecType.getElementType()); } -int64_t VectorBuilder::getMachineVectorLength(Value vecValue) const { +int64_t VectorBuilder::getArchVectorLength(Value vecValue) const { VectorType vecType = mlir::dyn_cast_or_null(vecValue.getType()); assert(vecType && "expected vector type"); - return getMachineVectorLength(vecType.getElementType()); + return getArchVectorLength(vecType.getElementType()); } Value VectorBuilder::load( @@ -1635,43 +1900,43 @@ Value VectorBuilder::reduction( loc(), vector::CombiningKind::MUL, value); } case CombiningKind::MAX: { - if (MathBuilder::isUnsignedIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorUnsignedInteger(type)) return b().create( loc(), vector::CombiningKind::MAXUI, value); - if (MathBuilder::isIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorInteger(type)) return b().create( loc(), vector::CombiningKind::MAXSI, value); - if (MathBuilder::isFloatWithVector(type)) + if (MathBuilder::isScalarOrVectorFloat(type)) return b().create( loc(), vector::CombiningKind::MAXNUMF, value); llvm_unreachable("unknown type in max"); } case CombiningKind::MIN: { - if (MathBuilder::isUnsignedIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorUnsignedInteger(type)) return b().create( loc(), vector::CombiningKind::MINUI, value); - if (MathBuilder::isIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorInteger(type)) return b().create( loc(), vector::CombiningKind::MINSI, value); - if (MathBuilder::isFloatWithVector(type)) + if (MathBuilder::isScalarOrVectorFloat(type)) return b().create( loc(), vector::CombiningKind::MINNUMF, value); llvm_unreachable("unknown type in min"); } case CombiningKind::AND: { - if (MathBuilder::isIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorInteger(type)) return b().create( loc(), vector::CombiningKind::AND, value); llvm_unreachable("unknown type in and"); } case CombiningKind::OR: { - if (MathBuilder::isIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorInteger(type)) return b().create( loc(), vector::CombiningKind::OR, value); llvm_unreachable("unknown type in or"); } case CombiningKind::XOR: { - if (MathBuilder::isIntegerWithVector(type)) + if (MathBuilder::isScalarOrVectorInteger(type)) return b().create( loc(), vector::CombiningKind::XOR, value); llvm_unreachable("unknown type in xor"); @@ -1693,13 +1958,13 @@ void VectorBuilder::multiReduction(SmallVectorImpl &inputVecArray, uint64_t N = inputVecArray.size(); assert(N > 0 && "expected at least one value to reduce"); uint64_t VL = getLengthOf1DVector(inputVecArray[0]); - uint64_t machineVL = getMachineVectorLength(inputVecArray[0]); + uint64_t archVL = getArchVectorLength(inputVecArray[0]); // TODO alex, should relax this - assert(VL == machineVL && "only natural sizes supported at this time"); - assert(N % machineVL == 0 && + assert(VL == archVL && "only natural sizes supported at this time"); + assert(N % archVL == 0 && "can only reduces multiple of VL vectors at this time"); LLVM_DEBUG(llvm::dbgs() << "reduction with N " << N << ", VL " << VL - << ", mVL " << machineVL << "\n";); + << ", archVL " << archVL << "\n";); // Emplace all input 
vectors in a temporary array. SmallVector tmpArray; @@ -1713,14 +1978,14 @@ void VectorBuilder::multiReduction(SmallVectorImpl &inputVecArray, // Reductions of full physical vectors. outputVecArray.clear(); MultiDialectBuilder create(*this); - // Process each block of machineVL input vectors at a time. - for (uint64_t r = 0; r < N; r += machineVL) { + // Process each block of archVL input vectors at a time. + for (uint64_t r = 0; r < N; r += archVL) { // Algorithm for the set of input arrays from tmp[r] to - // tmp[r+machineVL-1]. - // With machineVL inputs, we have machineVL/2 initial pairs. - uint64_t numPairs = machineVL / 2; + // tmp[r+archVL-1]. + // With archVL inputs, we have archVL/2 initial pairs. + uint64_t numPairs = archVL / 2; // While we have pairs... - for (uint64_t step = 1; step < machineVL; step = step * 2) { + for (uint64_t step = 1; step < archVL; step = step * 2) { // For each pair, reduce pair 2p and 2p+1 and save sum into p. for (uint64_t p = 0; p < numPairs; ++p) { Value highVal = @@ -1732,64 +1997,11 @@ void VectorBuilder::multiReduction(SmallVectorImpl &inputVecArray, } numPairs = numPairs / 2; // Pair number decrease by power of 2. } - // Completed the machineVL x machineVL reduction, save it in the output. + // Completed the archVL x archVL reduction, save it in the output. outputVecArray.emplace_back(tmpArray[r]); } } -int64_t VectorBuilder::computeSuitableUnrollFactor(VectorMachineSupport *vms, - MemRefType memRefType, llvm::SmallVectorImpl &memRefDims, - int64_t collapsedInnermostLoops, int64_t maxSimdUnroll, bool canPad, - int64_t &simdLoopStaticTripCount) const { - assert(collapsedInnermostLoops > 0 && "expected at least one collapsed loop"); - assert(maxSimdUnroll > 0 && "expected positive max simd unroll"); - simdLoopStaticTripCount = 0; // Initially assume no SIMD. - Type elementType = memRefType.getElementType(); - int64_t VL = vms->getVectorLength(elementType); - LLVM_DEBUG(llvm::dbgs() << " simd hw VL is " << VL << "\n"); - if (VL == 0) { - LLVM_DEBUG(llvm::dbgs() << " simd disabled: no simd\n"); - return 0; - } - MemRefBuilder createMem(*this); - int64_t staticSize; - IndexExpr dynSize; - bool isStaticSize = createMem.getStaticAndDynamicMemSize( - memRefType, memRefDims, staticSize, dynSize, -collapsedInnermostLoops); - if (isStaticSize && staticSize < VL) { - LLVM_DEBUG(llvm::dbgs() << " simd disabled: trip count " << staticSize - << " too short for a VL of " << VL << "\n"); - return 0; - } - // Unless otherwise disabled, here is the estimated trip count. - simdLoopStaticTripCount = staticSize > 1 ? staticSize : -1; - if (canPad && collapsedInnermostLoops == (int64_t)memRefType.getRank()) { - // Fully collapsed and can add padding to be fine - return maxSimdUnroll * VL; - } - // We have a partially flattened operator. Since we do only simdize entire - // loops (i.e. we don't support scalar epilogues at this time), make sure - // the static size is a multiple of the VL. Get the VL of the store - // (output's element type). - if (staticSize % VL != 0) { - LLVM_DEBUG(llvm::dbgs() << " simd disabled: partial flattened dims " - << collapsedInnermostLoops << " with size " - << staticSize << " is not 0 mod VL " << VL << "\n"); - return 0; - } - // See if we can get a unroll factor. 
- for (int64_t u = maxSimdUnroll; u > 0; --u) { - if (staticSize % (u * VL) == 0) { - LLVM_DEBUG(llvm::dbgs() - << " partial flattened dims " << collapsedInnermostLoops - << " with size " << staticSize << " works with VL " << VL - << " and unroll " << u << "\n"); - return u * VL; - } - } - llvm_unreachable("should always find u==1 feasible"); -} - //===----------------------------------------------------------------------===// // LLVM Builder //===----------------------------------------------------------------------===// @@ -1821,12 +2033,33 @@ void LLVMBuilder::br(ArrayRef destOperands, Block *destBlock) const { b().create(loc(), destOperands, destBlock); } +void LLVMBuilder::handleVarArgCall(LLVM::CallOp &callOp, + ArrayRef resultTypes, ArrayRef inputs) const { + // Define result type (void or 1). + Type resultType; + if (resultTypes.size() == 0 || isa(resultTypes[0])) { + MLIRContext *ctx = b().getContext(); + resultType = LLVM::LLVMVoidType::get(ctx); + } else { + resultType = resultTypes[0]; + } + // Define input types. + llvm::SmallVector inputTypes; + for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) + inputTypes.emplace_back(inputs[i].getType()); + auto typeSignature = + LLVM::LLVMFunctionType::get(resultType, inputTypes, /*is var arg*/ true); + callOp.setVarCalleeType(typeSignature); +} + Value LLVMBuilder::call(ArrayRef resultTypes, StringRef funcName, - ArrayRef inputs) const { + ArrayRef inputs, bool isVarArg) const { assert((resultTypes.size() == 0 || resultTypes.size() == 1) && "LLVM:CallOp must return either 0 or 1 value"); LLVM::CallOp callOp = b().create(loc(), resultTypes, funcName, inputs); + if (isVarArg) + handleVarArgCall(callOp, resultTypes, inputs); // CallOp may return either 0 or 1 value. if (resultTypes.empty()) return nullptr; @@ -1834,11 +2067,13 @@ Value LLVMBuilder::call(ArrayRef resultTypes, StringRef funcName, } Value LLVMBuilder::call(ArrayRef resultTypes, - FlatSymbolRefAttr funcSymbol, ArrayRef inputs) const { + FlatSymbolRefAttr funcSymbol, ArrayRef inputs, bool isVarArg) const { assert((resultTypes.size() == 0 || resultTypes.size() == 1) && "LLVM:CallOp must return either 0 or 1 value"); LLVM::CallOp callOp = b().create(loc(), resultTypes, funcSymbol, inputs); + if (isVarArg) + handleVarArgCall(callOp, resultTypes, inputs); // CallOp may return either 0 or 1 value. if (resultTypes.empty()) return nullptr; @@ -1923,8 +2158,8 @@ LLVM::LLVMFuncOp LLVMBuilder::func( return funcOp; // Create uniqueFuncOp if there exists a postfix. - // Since `funcOp` calls `uniqueFuncOp`, put `uniqueFuncOp`'s definition before - // `funcOp`. + // Since `funcOp` calls `uniqueFuncOp`, put `uniqueFuncOp`'s definition + // before `funcOp`. b().setInsertionPoint(funcOp); ModuleOp module = funcOp.getOperation()->getParentOfType(); std::string uniqueFuncName = diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index ba214b3944..26314c04f1 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIALECT_BUILDER_MLIR_H +#define ONNX_MLIR_DIALECT_BUILDER_MLIR_H #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -91,67 +92,82 @@ struct MathBuilder final : DialectBuilder { // Support for vectors: we provide queries that work regardless of if we have // (1) a scalar or (2) a vector of a basic element type. 
+ static bool isVector(mlir::Value val); static bool isVector(mlir::Type type); // The methods below ignore the vector part of the type to provide an answer on // the basic element types alone. - static bool isIntegerWithVector(mlir::Type elementOrVectorType); - static bool isUnsignedIntegerWithVector(mlir::Type elementOrVectorType); - static bool isFloatWithVector(mlir::Type elementOrVectorType); + static bool isScalarOrVectorInteger(mlir::Value val); + static bool isScalarOrVectorInteger(mlir::Type elementOrVectorType); + static bool isScalarOrVectorUnsignedInteger(mlir::Value val); + static bool isScalarOrVectorUnsignedInteger(mlir::Type elementOrVectorType); + static bool isScalarOrVectorFloat(mlir::Value val); + static bool isScalarOrVectorFloat(mlir::Type elementOrVectorType); // Return the basic element type regardless of if we are given (1) a scalar or // (2) a vector of a basic element type. - static mlir::Type elementTypeWithVector(mlir::Type elementOrVectorType); + static mlir::Type elementTypeOfScalarOrVector(mlir::Value val); + static mlir::Type elementTypeOfScalarOrVector(mlir::Type elementOrVectorType); // Return a type of the same vector shape as vectorType with a basic element - // type of elementType. When vectorType is null, then the returned type is - // simply a scalar of elementType. + // type of elementType. When vectorType is not a vector, then the returned + // type is simply a scalar of elementType. ElementType should not be a scalar + // type. static mlir::Type getTypeWithVector( - mlir::VectorType vectorType, mlir::Type elementType); + mlir::Type vectorType, mlir::Type elementType); + + // "B" below indicates that the operation will splat scalar values if one of + // the input values is itself a vector. mlir::Value abs(mlir::Value val) const; - mlir::Value add(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const; // Int only. - mlir::Value ceil(mlir::Value val) const; // Float only. - mlir::Value ceilDiv(mlir::Value lhs, mlir::Value rhs) const; // Int only. - mlir::Value copySign(mlir::Value rem, mlir::Value div) const; // Float only. - mlir::Value div(mlir::Value lhs, mlir::Value rhs) const; + mlir::Value add(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. + mlir::Value ceil(mlir::Value val) const; // Float only. + mlir::Value ceilDiv(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. + mlir::Value clip(mlir::Value val, mlir::Value lb, mlir::Value ub) const; // B. + mlir::Value copySign(mlir::Value rem, mlir::Value div) const; // B/Float only. + mlir::Value div(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value erf(mlir::Value val) const; mlir::Value exp(mlir::Value val) const; // Float only. mlir::Value exp2(mlir::Value val) const; // Float only. mlir::Value floor(mlir::Value val) const; // Float only. - mlir::Value floorDiv(mlir::Value lhs, mlir::Value rhs) const; // Int only. - mlir::Value fma(mlir::Value lhs, mlir::Value rhs, mlir::Value acc) const; - mlir::Value log(mlir::Value val) const; // Float only. - mlir::Value log2(mlir::Value val) const; // Float only. - mlir::Value mul(mlir::Value lhs, mlir::Value rhs) const; + mlir::Value floorDiv(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. + mlir::Value fma( + mlir::Value lhs, mlir::Value rhs, mlir::Value acc) const; // B.
+ mlir::Value log(mlir::Value val) const; // Float only. + mlir::Value log2(mlir::Value val) const; // Float only. + mlir::Value mul(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value neg(mlir::Value val) const; - mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const; // Int only. - mlir::Value pow(mlir::Value base, mlir::Value exp) const; // Float only. - mlir::Value rem(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value sqrt(mlir::Value val) const; // Float only. - mlir::Value sub(mlir::Value lhs, mlir::Value rhs) const; + mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. + mlir::Value pow(mlir::Value base, mlir::Value exp) const; // B/Float only. + mlir::Value rem(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value round(mlir::Value) const; // Float only. + mlir::Value sqrt(mlir::Value val) const; // Float only. + mlir::Value sub(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value tanh(mlir::Value val) const; // Float only. - mlir::Value xori(mlir::Value lhs, mlir::Value rhs) const; // Int only. + mlir::Value xori(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. mlir::Value select( - mlir::Value cmp, mlir::Value trueVal, mlir::Value valseVal) const; - mlir::Value gt(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value ge(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value lt(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value le(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value eq(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value neq(mlir::Value lhs, mlir::Value rhs) const; + mlir::Value cmp, mlir::Value trueVal, mlir::Value valseVal) const; // B. + mlir::Value gt(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value ge(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value lt(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value le(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value eq(mlir::Value lhs, mlir::Value rhs) const; // B. + mlir::Value neq(mlir::Value lhs, mlir::Value rhs) const; // B. // Signed versions (index/signless/signed int or float) - mlir::Value sgt(mlir::Value lhs, mlir::Value rhs) const; // No unsigned. - mlir::Value sge(mlir::Value lhs, mlir::Value rhs) const; // No unsigned. - mlir::Value slt(mlir::Value lhs, mlir::Value rhs) const; // No unsigned. - mlir::Value sle(mlir::Value lhs, mlir::Value rhs) const; // No unsigned. + mlir::Value sgt(mlir::Value lhs, mlir::Value rhs) const; // B/No unsigned. + mlir::Value sge(mlir::Value lhs, mlir::Value rhs) const; // B/No unsigned. + mlir::Value slt(mlir::Value lhs, mlir::Value rhs) const; // B/No unsigned. + mlir::Value sle(mlir::Value lhs, mlir::Value rhs) const; // B/No unsigned. // Unsigned versions - mlir::Value ugt(mlir::Value lhs, mlir::Value rhs) const; // Unsigned int only - mlir::Value uge(mlir::Value lhs, mlir::Value rhs) const; // Unsigned int only - mlir::Value ult(mlir::Value lhs, mlir::Value rhs) const; // Unsigned int only - mlir::Value ule(mlir::Value lhs, mlir::Value rhs) const; // Unsigned int only + mlir::Value ugt(mlir::Value lhs, mlir::Value rhs) const; // B/Unsigned only. + mlir::Value uge(mlir::Value lhs, mlir::Value rhs) const; // B/Unsigned only. + mlir::Value ult(mlir::Value lhs, mlir::Value rhs) const; // B/Unsigned only. + mlir::Value ule(mlir::Value lhs, mlir::Value rhs) const; // B/Unsigned only. - mlir::Value min(mlir::Value lhs, mlir::Value rhs) const; - mlir::Value max(mlir::Value lhs, mlir::Value rhs) const; + mlir::Value min(mlir::Value lhs, mlir::Value rhs) const; // B. 
+ mlir::Value max(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value constant(mlir::Type type, double val) const; mlir::Value constantIndex(int64_t val) const; @@ -172,7 +188,7 @@ struct MathBuilder final : DialectBuilder { // Cast handle bool/int/float/index elementary types. Do not convert // signed/index to unsigned. - mlir::Value cast(mlir::Type destType, mlir::Value val) const; + mlir::Value cast(mlir::Type destType, mlir::Value val) const; // B. mlir::Value castToIndex(mlir::Value val) const; // Add indexOffsets to the least significant indices. So if indices are (i, j, @@ -183,6 +199,8 @@ struct MathBuilder final : DialectBuilder { void addOffsetToLeastSignificant(mlir::ArrayRef indices, mlir::ValueRange offsets, llvm::SmallVectorImpl &computedIndices) const; + // Perform splat to match (see below), accepting up to 3 values at most. + void splatToMatch(llvm::SmallVectorImpl &vals) const; private: mlir::Value createArithCmp( @@ -191,6 +209,12 @@ struct MathBuilder final : DialectBuilder { mlir::Value lhs, mlir::Value rhs, mlir::arith::CmpFPredicate pred) const; mlir::Value castToSignless(mlir::Value source, int64_t width) const; mlir::Value castToUnsigned(mlir::Value source, int64_t width) const; + + // If any of the first, second, or third values are vector types, splat the + // other ones to the same VL. Return true if one or more values were splatted. + bool splatToMatch(mlir::Value &first, mlir::Value &second) const; + bool splatToMatch( + mlir::Value &first, mlir::Value &second, mlir::Value &third) const; }; //===----------------------------------------------------------------------===// @@ -245,6 +269,9 @@ struct MemRefBuilder final : DialectBuilder { bool getStaticAndDynamicMemSize(mlir::MemRefType type, llvm::SmallVectorImpl &dims, int64_t &staticSize, IndexExpr &dynSize, int64_t range = 1000) const; + // Same as above, but does not track of dynamic size. + static bool getStaticMemSize( + mlir::MemRefType type, int64_t &staticSize, int64_t range = 1000); // Alloc for static shapes without alignment. mlir::memref::AllocOp alloc(mlir::MemRefType type) const; @@ -269,11 +296,11 @@ struct MemRefBuilder final : DialectBuilder { llvm::SmallVectorImpl &dims, int64_t align = defaultAlign) const; - // Alloc for shapes with alignment and padding for safe full SIMD operations. - // Padding may be added so that every values in the shape may safely be - // computed by a SIMD operation (or possibly multiple ones when simdUnroll>1). - // Minimum alignment is gDefaultAllocAlign. - // Operation does not support layouts at this time. + // Alloc for shapes with alignment and padding for safe full SIMD + // operations. Padding may be added so that every values in the shape may + // safely be computed by a SIMD operation (or possibly multiple ones if + // unrollSIMD>1). Minimum alignment is gDefaultAllocAlign. Operation does + // not support layouts at this time. // // Alloc for static shapes with alignment and SIMD padding. 
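The "B" annotations and the new `splatToMatch` helpers above describe MathBuilder's splat-to-match behavior for element-wise emitters that mix scalar and vector operands. A minimal sketch of the intended use, with illustrative names and under the assumption that `vec` carries a vector of f32 (not taken from this patch):

```cpp
#include "src/Dialect/Mlir/DialectBuilder.hpp"

// Sketch only: add a scalar bias to a vector value. Since add() is marked
// "B", the scalar constant is expected to be splatted to the vector length
// of `vec` before the addition is emitted.
mlir::Value addScalarBias(
    mlir::OpBuilder &b, mlir::Location loc, mlir::Value vec, double bias) {
  onnx_mlir::MultiDialectBuilder<onnx_mlir::MathBuilder> create(b, loc);
  mlir::Type f32Ty = b.getF32Type();
  mlir::Value scalarBias = create.math.constant(f32Ty, bias); // Scalar f32.
  return create.math.add(vec, scalarBias); // Mixed vector/scalar operands.
}
```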
mlir::Value alignedAllocWithSimdPadding(mlir::MemRefType type, int64_t VL = 1, @@ -286,13 +313,13 @@ struct MemRefBuilder final : DialectBuilder { mlir::MemRefType type, int64_t VL = 1, int64_t align = defaultAlign) const; mlir::Value alignedAllocWithSimdPadding(mlir::MemRefType type, - llvm::SmallVectorImpl &dims, int64_t simdVLUnroll = 1, + llvm::SmallVectorImpl &dims, int64_t VL = 1, int64_t align = defaultAlign) const; - // The alloca instruction allocates memory on the stack frame of the currently - // executing function, to be automatically released when this function returns - // to its caller. It is strongly suggested to place alloca instructions - // outside of a loop. + // The alloca instruction allocates memory on the stack frame of the + // currently executing function, to be automatically released when this + // function returns to its caller. It is strongly suggested to place alloca + // instructions outside of a loop. mlir::memref::AllocaOp alloca(mlir::MemRefType type) const; mlir::memref::AllocaOp alignedAlloca( mlir::MemRefType type, int64_t align = defaultAlign) const; @@ -306,20 +333,20 @@ struct MemRefBuilder final : DialectBuilder { // hold the dims, save into it, and the perform the actual reshape. mlir::memref::ReshapeOp reshape(llvm::SmallVectorImpl &outputDims, mlir::Value valToReshape) const; - // Flatten innermost dimensions of a MemRef. User provide the value to reshape - // (valToReshape), its dims (dims), and the number of innermost loops to - // collapse (dimsToFlatten). The function computes the new flattened - // dimensions (flattenDims) and return the flattened value. Values of - // dimsToFlatten are in the [1, rank of input] range. Legal only on types + // Flatten innermost dimensions of a MemRef. User provide the value to + // reshape (valToReshape), its dims (dims), and the number of innermost + // loops to collapse (dimsToFlatten). The function computes the new + // flattened dimensions (flattenDims) and return the flattened value. Values + // of dimsToFlatten are in the [1, rank of input] range. Legal only on types // with identity layouts. mlir::Value reshapeToFlatInnermost(mlir::Value valToReshape, llvm::SmallVectorImpl &dims, llvm::SmallVectorImpl &flattenDims, int64_t dimsToFlatten) const; - // Flatten to a 2D MemRef, with outer dim including outermost dim to axis -1, - // and inner dim including the remaining innermost dims. Values of axis are - // in the [1, rank of input) range. Negative axis values are taken from the - // back. Legal only on types with identity layouts. + // Flatten to a 2D MemRef, with outer dim including outermost dim to axis + // -1, and inner dim including the remaining innermost dims. Values of axis + // are in the [1, rank of input) range. Negative axis values are taken from + // the back. Legal only on types with identity layouts. mlir::Value reshapeToFlat2D(mlir::Value valToReshape, llvm::SmallVectorImpl &dims, llvm::SmallVectorImpl &flattenDims, int64_t axis) const; @@ -401,8 +428,8 @@ struct SCFBuilder final : DialectBuilder { SCFBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~SCFBuilder() {} - /// Create an if then with optional else. Construct does not generate a result - /// (unlike some scf::if) and introduces the yields automatically. + /// Create an if then with optional else. Construct does not generate a + /// result (unlike some scf::if) and introduces the yields automatically. 
void ifThenElse(mlir::Value cond, mlir::function_ref thenFn, mlir::function_ref elseFn = nullptr) const; @@ -430,15 +457,21 @@ struct VectorBuilder final : DialectBuilder { using F2 = std::function; enum CombiningKind { ADD, MUL, MAX, MIN, AND, OR, XOR }; + // Check if two types have compatible shapes (assuming that scalar will be + // splatted to the proper vector shape), + static bool compatibleShapes(const mlir::Type t1, const mlir::Type t2); + // Check that the two types have identical elementary types and shapes. + static bool compatibleTypes(const mlir::Type t1, const mlir::Type t2); + // Get the machine SIMD vector length for the given elementary type. // This can help guide certain optimizations. - int64_t getMachineVectorLength(const mlir::Type &elementType) const; - int64_t getMachineVectorLength(const mlir::VectorType &vecType) const; - int64_t getMachineVectorLength(mlir::Value vecValue) const; + int64_t getArchVectorLength(const mlir::Type &elementType) const; + int64_t getArchVectorLength(const mlir::VectorType &vecType) const; + int64_t getArchVectorLength(mlir::Value vecValue) const; - // Vector load: memref is expected to be scalar, will load a vector's worth of - // values: e.g. - // %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>. + // Vector load: memref is expected to be scalar, will load a vector's worth + // of values: e.g. %result = vector.load %base[%i, %j] : + // memref<100x100xf32>, vector<8xf32>. mlir::Value load(mlir::VectorType vecType, mlir::Value memref, mlir::ValueRange indices = {}) const; // When ranks of offsets &inputVecArray, F2 reductionFct, llvm::SmallVectorImpl &outputVecArray); - // Compute a suitable SIMD Vector length (which may be a multiple of the - // hardware vector length, up to maxSimdUnroll times). If the dims are too - // small, return 0 (no suitable simd). The collapsedInnermostLoops parameter - // indicates how many inner dimensions of the memref are considered for - // vectorization. If all of them are considered and padding is possible, then - // we can always generate SIMD code with the maxSIMD unroll factor. Otherwise, - // we must ensure that the cumulative static size (dynamic sizes are ignored - // here ) of the array is a multiple of the Vector Length associated with this - // type. If it is not, then no SIMD code gen is possible (return 0). If it is - // possible, return the largest SIMD unroll factor (starting at maxSimdUnroll) - // that divide the cumulative static size of the memref being collapsed for - // SIMD. - // simdLoopStaticTripCount: provide an estimation of the SIMD loop trip - // count. If runtime, return -1; if cannot simdize, return 0; if compile time - // (or a multiple of a compile time value): return that literal. - int64_t computeSuitableUnrollFactor(VectorMachineSupport *vms, - mlir::MemRefType memRefType, llvm::SmallVectorImpl &memRefDims, - int64_t collapsedInnermostLoops, int64_t maxSimdUnroll, bool canPad, - int64_t &simdLoopStaticTripCount) const; - private: bool isPowerOf2(uint64_t num) const; uint64_t getLengthOf1DVector(mlir::Value vec) const; @@ -566,8 +579,8 @@ struct GenericAffineBuilder final : DialectBuilder { // Affine builder uses affine load and store for memory operations. A later // definition of AffineBuilderKrnlMem will use Krnl load and store for memory -// operations. We recommend to use AffineBuilderKrnlMem when converting the Krnl -// dialect into the affine dialect. +// operations. 
We recommend to use AffineBuilderKrnlMem when converting the +// Krnl dialect into the affine dialect. using AffineBuilder = GenericAffineBuilder; @@ -607,10 +620,11 @@ struct LLVMBuilder final : DialectBuilder { // CallOp mlir::Value call(mlir::ArrayRef resultTypes, - llvm::StringRef funcName, mlir::ArrayRef inputs) const; + llvm::StringRef funcName, mlir::ArrayRef inputs, + bool isVarArg = false) const; mlir::Value call(mlir::ArrayRef resultTypes, - mlir::FlatSymbolRefAttr funcSymbol, - mlir::ArrayRef inputs) const; + mlir::FlatSymbolRefAttr funcSymbol, mlir::ArrayRef inputs, + bool isVarArg = false) const; // CondBrOp void condBr(mlir::Value cond, mlir::Block *trueBlock, @@ -629,7 +643,8 @@ struct LLVMBuilder final : DialectBuilder { mlir::Value extractValue(mlir::Type resultType, mlir::Value container, llvm::ArrayRef position) const; - // FuncOp + // FuncOp (assume non-variadic functions, otherwise add support like in + // seen in `call` in this file). mlir::LLVM::LLVMFuncOp func(llvm::StringRef name, mlir::Type type, bool createUniqueFunc = false) const; @@ -732,6 +747,12 @@ struct LLVMBuilder final : DialectBuilder { } return symbol + postfix; } + +private: + // Support for calling vararg functions; add the necessary type. + void handleVarArgCall(mlir::LLVM::CallOp &callOp, + mlir::ArrayRef resultTypes, + mlir::ArrayRef inputs) const; }; //===----------------------------------------------------------------------===// @@ -864,3 +885,4 @@ struct MultiDialectBuilder : MultiDialectBuilder { #include "DialectBuilder.hpp.inc" } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Mlir/IndexExpr.hpp b/src/Dialect/Mlir/IndexExpr.hpp index dba5aeef5a..0c867d0586 100644 --- a/src/Dialect/Mlir/IndexExpr.hpp +++ b/src/Dialect/Mlir/IndexExpr.hpp @@ -4,7 +4,7 @@ //===----------------IndexExpr.hpp - Index expression---------------------=== // // -// Copyright 2020-2023 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_INDEX_EXPR_H +#define ONNX_MLIR_INDEX_EXPR_H #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -402,6 +403,8 @@ class IndexExprScope { // IndexExprExpr //===----------------------------------------------------------------------===// +using DimsExpr = llvm::SmallVector; + // Data structure that is the public interface for IndexExpr. It is a shallow // data structure that is simply a pointer to the actual data (IndexExprImpl). class IndexExpr { @@ -879,7 +882,7 @@ inline llvm::SmallVector SymListIE(mlir::ValueRange range) { // Create a list of IndexExpr of kind INDEX_EXPR from another list of IndexExpr. 
template -void getIndexExprList(llvm::SmallVectorImpl &inputList, +void getIndexExprList(const llvm::SmallVectorImpl &inputList, llvm::SmallVectorImpl &outputList) { outputList.clear(); for (auto item : inputList) @@ -887,14 +890,14 @@ void getIndexExprList(llvm::SmallVectorImpl &inputList, } inline llvm::SmallVector DimListIE( - llvm::SmallVectorImpl &inputList) { + const llvm::SmallVectorImpl &inputList) { llvm::SmallVector outputList; getIndexExprList(inputList, outputList); return outputList; } inline llvm::SmallVector SymListIE( - llvm::SmallVectorImpl &inputList) { + const llvm::SmallVectorImpl &inputList) { llvm::SmallVector outputList; getIndexExprList(inputList, outputList); return outputList; @@ -910,3 +913,4 @@ void getIndexExprListFromShape(mlir::ArrayRef inputList, llvm::SmallVectorImpl &outputList); } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Mlir/IndexExprBuilder.hpp b/src/Dialect/Mlir/IndexExprBuilder.hpp index 3d14303612..5d961a1a68 100644 --- a/src/Dialect/Mlir/IndexExprBuilder.hpp +++ b/src/Dialect/Mlir/IndexExprBuilder.hpp @@ -4,7 +4,7 @@ //===---------------- ONNXShapeHelper.hpp - help for shapes ---------------===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -22,7 +22,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_INDEX_EXPR_BUILDER_H +#define ONNX_MLIR_INDEX_EXPR_BUILDER_H #include @@ -230,3 +231,4 @@ struct IndexExprBuilder : DialectBuilder { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Mlir/IndexExprDetail.hpp b/src/Dialect/Mlir/IndexExprDetail.hpp index 9a0ba29d9a..177462c24c 100644 --- a/src/Dialect/Mlir/IndexExprDetail.hpp +++ b/src/Dialect/Mlir/IndexExprDetail.hpp @@ -5,7 +5,7 @@ //===------------- IndexExprDetail.hpp - Index expression details ---------===// // // -// Copyright 2020-2022 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -14,7 +14,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_INDEX_EXPR_DETAIL_H +#define ONNX_MLIR_INDEX_EXPR_DETAIL_H #include "src/Dialect/Mlir/IndexExpr.hpp" @@ -124,3 +125,4 @@ class IndexExprImpl { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index f25818e9a8..3ac5bdcb01 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -4,12 +4,15 @@ //===-- VectorMachineSupport.cpp - Helper for what SIMD ops are supported -===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. 
// // ============================================================================= #include "src/Dialect/Mlir/VectorMachineSupport.hpp" + #include "mlir/IR/BuiltinTypes.h" +#include "llvm/Support/Debug.h" + #include #define DEBUG_TYPE "dialect_builder" @@ -51,6 +54,7 @@ namespace onnx_mlir { } assert(globalVectorMachineSupport && "failed to allocate vector machine support"); + LLVM_DEBUG(llvm::dbgs() << "use SIMD arch " << getArchName() << "\n"); } /*static*/ void VectorMachineSupport::clearGlobalVectorMachineSupport() { @@ -60,72 +64,69 @@ namespace onnx_mlir { globalVectorMachineSupport = nullptr; } -/*static*/ bool VectorMachineSupport::hasSimd() { - return getGlobalVectorMachineSupport()->VectorRegisterNum() > 0; -} // ============================================================================= // Methods shared among all VectorMachineSupport classes and subclasses -int64_t VectorMachineSupport::getVectorLength(Type elementType) { +int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { if (!hasSimd()) return 0; - int64_t simdBitSize = getVectorBitWidth(); + int64_t simdBitSize = computeArchVectorBitWidth(); int64_t typeBitSize = elementType.getIntOrFloatBitWidth(); assert(simdBitSize >= typeBitSize && simdBitSize % typeBitSize == 0 && "bad machine vector length"); return (simdBitSize / typeBitSize); } -double VectorMachineSupport::getAvgVectorLength(ArrayRef &gops, - ArrayRef &gopsNum, Type elementType, int64_t &vectorizedOpNum, - int64_t &scalarOpNum) { - assert(gopsNum.size() == gops.size() && "expect same length for both lists"); - int64_t gopsSize = gops.size(); +/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps, + Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) { + int64_t size = genOps.size(); if (!hasSimd()) { vectorizedOpNum = 0; - scalarOpNum = gopsSize; - return 0; + scalarOpNum = size; + return 1; } int64_t totProcessedValues = 0.0; vectorizedOpNum = 0; scalarOpNum = 0; // Determine which operations support SIMD and accumulate their vector // lengths. - for (int64_t i = 0; i < gopsSize; ++i) { - int64_t vl = getVectorLength(gops[i], elementType); + for (auto pair : genOps) { + GenericOps genOp = pair.first; + int64_t num = pair.second; + int64_t vl = getArchVectorLength(genOp, elementType); // If past last value, assume 1; otherwise use actual value. - int64_t num = gopsNum[i]; // Accumulate weighted scalar/vectorized num and vl length. if (vl > 0) vectorizedOpNum += num; else scalarOpNum += num; - // For totVL, when an operation is scalar, it still process 1 element + // For VL, when an operation is scalar, it still process 1 element int64_t processedValues = std::max((int64_t)1, vl); totProcessedValues += processedValues * num; } // Compute final values int64_t totNum = vectorizedOpNum + scalarOpNum; - scalarOpNum = gopsSize - vectorizedOpNum; - return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 0.0; + scalarOpNum = size - vectorizedOpNum; + return totNum != 0 ? 
(1.0 * totProcessedValues) / (1.0 * totNum) : 1.0; } // ============================================================================= // IBM Z servers // ============================================================================= -int64_t Z16VectorMachineSupport::getVectorLength( +int64_t Z16VectorMachineSupport::computeArchVectorLength( GenericOps Gop, Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); - int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { - // 1 - 16 byte operations. + case GenericOps::ScalarOnlyGop: + return 1; // Must be scalar. case GenericOps::SelectGop: case GenericOps::ShuffleGop: - return abstractVL; + return archVL; // 1 - 16 byte operations. default: // Continue with typed tests. break; @@ -133,8 +134,8 @@ int64_t Z16VectorMachineSupport::getVectorLength( // Support for float. if (isFloat) { - // Supports only 32 and 64 bit Floats; There is support for extended too but - // ignore this for now. + // Supports only 32 and 64 bit Floats; There is support for extended too + // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || (bitWidth == 16 && Gop == GenericOps::ConversionGop))) return UNSUPPORTED; @@ -153,7 +154,7 @@ int64_t Z16VectorMachineSupport::getVectorLength( case GenericOps::MinMaxGop: case GenericOps::MulGop: case GenericOps::SqrtGop: - return abstractVL; + return archVL; default: // Unsupported float op. return UNSUPPORTED; @@ -165,7 +166,7 @@ int64_t Z16VectorMachineSupport::getVectorLength( case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: case GenericOps::LogicalGop: - return abstractVL; + return archVL; // 1 - 8 byte operations. case GenericOps::AbsGop: /* supported via compare and select */ @@ -175,7 +176,7 @@ int64_t Z16VectorMachineSupport::getVectorLength( case GenericOps::MulGop: case GenericOps::ShiftGop: case GenericOps::SumAcrossGop: - return bitWidth <= 64 ? abstractVL : UNSUPPORTED; + return bitWidth <= 64 ? archVL : UNSUPPORTED; default: // Unsupported integer op. return UNSUPPORTED; @@ -188,18 +189,19 @@ int64_t Z16VectorMachineSupport::getVectorLength( // This may be an approximation of the actual capabilities. // ============================================================================= -int64_t SSE42x86VectorMachineSupport::getVectorLength( +int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( GenericOps Gop, mlir::Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); - int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { - // 1 - 16 byte operations. + case GenericOps::ScalarOnlyGop: + return 1; // Must be scalar. case GenericOps::SelectGop: case GenericOps::ShuffleGop: - return abstractVL; + return archVL; //// 1 - 16 byte operations. default: // Continue with typed tests. break; @@ -207,8 +209,8 @@ int64_t SSE42x86VectorMachineSupport::getVectorLength( // Support for float. if (isFloat) { - // Supports only 32 and 64 bit Floats; There is support for extended too but - // ignore this for now. + // Supports only 32 and 64 bit Floats; There is support for extended too + // but ignore this for now. 
if (!(bitWidth == 32 || bitWidth == 64 || (bitWidth == 16 && Gop == GenericOps::ConversionGop))) return UNSUPPORTED; @@ -228,7 +230,7 @@ int64_t SSE42x86VectorMachineSupport::getVectorLength( case GenericOps::RoundGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: - return abstractVL; + return archVL; default: // Unsupported float op. return UNSUPPORTED; @@ -243,23 +245,23 @@ int64_t SSE42x86VectorMachineSupport::getVectorLength( case GenericOps::MinMaxGop: case GenericOps::CompareGop: case GenericOps::AbsGop: - return abstractVL; + return archVL; // 1 - 8 byte operations. case GenericOps::ShiftGop: - return bitWidth <= 64 ? abstractVL : UNSUPPORTED; + return bitWidth <= 64 ? archVL : UNSUPPORTED; // 1 - 4 byte operations. case GenericOps::FmaGop: - return bitWidth <= 32 ? abstractVL : UNSUPPORTED; + return bitWidth <= 32 ? archVL : UNSUPPORTED; // 4 - 16 byte operations. case GenericOps::MulGop: - return bitWidth >= 32 && bitWidth <= 128 ? abstractVL : UNSUPPORTED; + return bitWidth >= 32 && bitWidth <= 128 ? archVL : UNSUPPORTED; // 4 - 8 byte operations. case GenericOps::SumAcrossGop: - return bitWidth >= 32 && bitWidth <= 64 ? abstractVL : UNSUPPORTED; + return bitWidth >= 32 && bitWidth <= 64 ? archVL : UNSUPPORTED; default: // Unsupported integer op. @@ -273,18 +275,19 @@ int64_t SSE42x86VectorMachineSupport::getVectorLength( // This may be an approximation of the actual capabilities. // ============================================================================= -int64_t NeonVectorMachineSupport::getVectorLength( +int64_t NeonVectorMachineSupport::computeArchVectorLength( GenericOps Gop, mlir::Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); - int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); + int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { - // 1 - 16 byte operations. + case GenericOps::ScalarOnlyGop: + return 1; // Must be scalar. case GenericOps::SelectGop: case GenericOps::ShuffleGop: - return abstractVL; + return archVL; // 1 - 16 byte operations. default: // Continue with typed tests. break; @@ -312,7 +315,7 @@ int64_t NeonVectorMachineSupport::getVectorLength( case GenericOps::RoundGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: - return abstractVL; + return archVL; default: // Unsupported float op. return UNSUPPORTED; @@ -327,23 +330,23 @@ int64_t NeonVectorMachineSupport::getVectorLength( case GenericOps::MinMaxGop: case GenericOps::CompareGop: case GenericOps::AbsGop: - return abstractVL; + return archVL; // 1 - 8 byte operations. case GenericOps::ShiftGop: - return bitWidth <= 64 ? abstractVL : UNSUPPORTED; + return bitWidth <= 64 ? archVL : UNSUPPORTED; // 1 - 4 byte operations. case GenericOps::FmaGop: - return bitWidth <= 32 ? abstractVL : UNSUPPORTED; + return bitWidth <= 32 ? archVL : UNSUPPORTED; // 4 - 16 byte operations. case GenericOps::MulGop: - return bitWidth >= 32 && bitWidth <= 128 ? abstractVL : UNSUPPORTED; + return bitWidth >= 32 && bitWidth <= 128 ? archVL : UNSUPPORTED; // 4 - 8 byte operations. case GenericOps::SumAcrossGop: - return bitWidth >= 32 && bitWidth <= 64 ? abstractVL : UNSUPPORTED; + return bitWidth >= 32 && bitWidth <= 64 ? archVL : UNSUPPORTED; default: // Unsupported integer op. 
@@ -352,4 +355,27 @@ int64_t NeonVectorMachineSupport::getVectorLength(
   llvm_unreachable("should have handled all cases above");
 }
 
+// =============================================================================
+// Support for Generic Operation Mix
+
+GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
+  GenOpMix u;
+  // Pick ops from the first mix.
+  for (auto pair : mix1) {
+    GenericOps genOp = pair.first;
+    int64_t num = pair.second;
+    u[genOp] = num;
+  }
+  // Merge entries from the second mix.
+  for (auto pair : mix2) {
+    GenericOps genOp = pair.first;
+    int64_t num = pair.second;
+    if (u.find(genOp) != u.end())
+      u[genOp] += num; // Has this op already, add to it.
+    else
+      u[genOp] = num;
+  }
+  return u;
+}
+
 } // namespace onnx_mlir
diff --git a/src/Dialect/Mlir/VectorMachineSupport.hpp b/src/Dialect/Mlir/VectorMachineSupport.hpp
index c88848fa57..bcd2ad1a88 100644
--- a/src/Dialect/Mlir/VectorMachineSupport.hpp
+++ b/src/Dialect/Mlir/VectorMachineSupport.hpp
@@ -4,7 +4,7 @@
 //===-- VectorMachineSupport.hpp - Helper for what SIMD ops are supported -===//
 //
-// Copyright 2023 The IBM Research Authors.
+// Copyright 2023-2024 The IBM Research Authors.
 //
 // =============================================================================
 //
@@ -15,7 +15,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#pragma once
+#ifndef ONNX_MLIR_VECTOR_MACHINE_H
+#define ONNX_MLIR_VECTOR_MACHINE_H
 
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/SmallVector.h"
@@ -32,7 +33,7 @@ namespace onnx_mlir {
 
 enum class GenericOps {
   AbsGop,
-  ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity */
+  ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */
   CeilDivGop,
   CeilGop,
   CompareGop, /* All compare operations, signed/unsigned fixed/float. */
@@ -47,11 +48,12 @@ enum class GenericOps {
   LogGop,
   LogicalGop, /* All logical ops: and, or, xor, not, nor, nand,... */
   MinMaxGop,
-  MinMaxAcrossGop, /* compute min/max across vector */
+  MinMaxAcrossGop, /* Compute min/max across vector. */
   MulGop,
   PowGop,
   RemGop,
   RoundGop,
+  ScalarOnlyGop, /* Any ops that are guaranteed to be scalar on any arch. */
   SelectGop,
   ShiftGop,   /* Shift operations: logical/arithmetic. */
   ShuffleGop, /* All bit/byte moving operations: shuffle, rotate, shift. */
@@ -62,6 +64,13 @@ enum class GenericOps {
   TrigHyperbolicGop, /* Hyperbolic trig. */
 };
 
+// Describes the mix of Generic operations in a given kernel. Each generic
+// operation is associated with a number, which indicates the number of
+// occurrences of that generic op in the given kernel.
+using GenOpMix = llvm::SmallDenseMap;
+
+GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2);
+
 //===----------------------------------------------------------------------===//
 // Generic vector machine support class, which must be refined for each
 // supported machine type.
@@ -75,51 +84,73 @@ class VectorMachineSupport {
   // Must call setGlobalVectorMachineSupport once before using any calls below.
   static void setGlobalVectorMachineSupport(
       std::string arch, std::string cpu, std::string attr);
-  // Get the defined vector machine support.
-  static VectorMachineSupport *getGlobalVectorMachineSupport() {
-    assert(globalVectorMachineSupport && "vector machine support undefined");
-    return globalVectorMachineSupport;
-  }
   static void clearGlobalVectorMachineSupport();
 
+  static std::string getArchName() { return vms()->computeArchName(); }
+
+  // Determine if the machine has simd.
Requires an initialized vector machine // support. - static bool hasSimd(); + static bool hasSimd() { return getArchVectorRegisterNum() > 0; } // When querying Vector length for machines with unsupported simd, UNSUPPORTED // (aka 0) is returned. - static const int64_t UNSUPPORTED = 0; + static const int64_t UNSUPPORTED = 1; // Number of vector registers available. - virtual int64_t VectorRegisterNum() = 0; + static int64_t getArchVectorRegisterNum() { + // Indirection to the object specific to a subclass. + return vms()->computeArchVectorRegisterNum(); + } // Return the bit width of the SIMD unit regardless of the type/operation. // This is an upper bound and does not guarantee that an actual operation can // provide this VL. A value of zero means no SIMD available. - virtual int64_t getVectorBitWidth() = 0; + static int64_t getArchVectorBitWidth() { + // Indirection to the object specific to a subclass. + return vms()->computeArchVectorBitWidth(); + } // Return the number of elements that can be processed in SIMD fashion // regardless of the operation. This is an upper bound and does not guarantee // that an actual operation can provide this VL. A value of zero means no SIMD // available. - virtual int64_t getVectorLength(mlir::Type elementType); + static int64_t getArchVectorLength(mlir::Type elementType) { + // Indirection to the object specific to a subclass. + return vms()->computeArchVectorLength(elementType); + } + // Return the number of elements that can be processed in SIMD fashion if // support exists. A value of zero means no SIMD available. - virtual int64_t getVectorLength(GenericOps gop, mlir::Type elementType) = 0; + static int64_t getArchVectorLength(GenericOps gop, mlir::Type elementType) { + // Indirection to the object specific to a subclass. + return vms()->computeArchVectorLength(gop, elementType); + } // Analyze the benefits of using SIMD on a list of generic ops in an algorithm // where each op on the list occurs a given number of times. The function // returns the weighted average vector length among the operations listed in - // the gops list, where each operation gops[i] occur exactly gopsNum[i] times - // in the algorithm. Note that scalar operation have a vector length of - // one in the weighted average as they still contribute one result. The opNums - // are also weighted by the gopsNum to better represent the mix of - // vectorized and scalar operations present in the algorithm. - double getAvgVectorLength(mlir::ArrayRef &gops, - mlir::ArrayRef &gopsNum, mlir::Type elementType, + // the GenOps list, where each entry is a pair of generic operation and the + // number of times that generic operation was found. Note that scalar + // operation have a vector length of one in the weighted average as they still + // contribute one result. + static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum); +protected: + // Virtual functions that do the actual work. Called by the "get" functions. 
+ virtual std::string computeArchName() = 0; + virtual int64_t computeArchVectorRegisterNum() = 0; + virtual int64_t computeArchVectorBitWidth() = 0; + virtual int64_t computeArchVectorLength(mlir::Type elementType); + virtual int64_t computeArchVectorLength( + GenericOps gop, mlir::Type elementType) = 0; + private: - static VectorMachineSupport *globalVectorMachineSupport; + static VectorMachineSupport *vms() { + assert(globalVectorMachineSupport && "vector machine support undefined"); + return globalVectorMachineSupport; + } + + static VectorMachineSupport *globalVectorMachineSupport; // Init to null. }; // No support for SIMD. @@ -128,12 +159,14 @@ class NoVectorMachineSupport : public VectorMachineSupport { NoVectorMachineSupport() = default; virtual ~NoVectorMachineSupport() = default; - int64_t VectorRegisterNum() override { return 0; } - int64_t getVectorBitWidth() override { return 0; } - int64_t getVectorLength(mlir::Type elementType) override { + std::string computeArchName() override { return "no_vector"; } + int64_t computeArchVectorRegisterNum() override { return 0; } + int64_t computeArchVectorBitWidth() override { return 0; } + int64_t computeArchVectorLength(mlir::Type elementType) override { return UNSUPPORTED; } - int64_t getVectorLength(GenericOps gop, mlir::Type elementType) override { + int64_t computeArchVectorLength( + GenericOps gop, mlir::Type elementType) override { return UNSUPPORTED; } }; @@ -145,9 +178,11 @@ class Z16VectorMachineSupport : public VectorMachineSupport { Z16VectorMachineSupport() = default; virtual ~Z16VectorMachineSupport() = default; - int64_t VectorRegisterNum() override { return 32; } - int64_t getVectorBitWidth() override { return 128; } - int64_t getVectorLength(GenericOps gop, mlir::Type elementType) override; + std::string computeArchName() override { return "z16"; } + int64_t computeArchVectorRegisterNum() override { return 32; } + int64_t computeArchVectorBitWidth() override { return 128; } + int64_t computeArchVectorLength( + GenericOps gop, mlir::Type elementType) override; }; // TODO: create models for z14 and z15. 
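The query interface is now fully static: callers use the `getArch*` wrappers and `hasSimd` instead of fetching the global `VectorMachineSupport` object, and a `GenOpMix` replaces the two parallel arrays previously passed to the average-vector-length query. A short sketch of the intended usage, with illustrative op counts and assuming `setGlobalVectorMachineSupport` was called during compiler setup:

```cpp
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"

// Sketch only: estimate the average SIMD width of a small f32 kernel made of
// two adds, one multiply, and one op that must stay scalar (counts are
// illustrative, not from this patch).
double estimateKernelSimdWidth(mlir::Type f32Type) {
  using namespace onnx_mlir;
  if (!VectorMachineSupport::hasSimd())
    return 1.0; // No SIMD on this target: every op handles one element.
  GenOpMix mix;
  mix[GenericOps::ArithmeticGop] = 2;
  mix[GenericOps::MulGop] = 1;
  mix[GenericOps::ScalarOnlyGop] = 1;
  int64_t vectorizedOpNum, scalarOpNum;
  // Weighted average vector length; scalar ops still count as length 1.
  return VectorMachineSupport::getAvgArchVectorLength(
      mix, f32Type, vectorizedOpNum, scalarOpNum);
}
```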
@@ -160,9 +195,11 @@ class SSE42x86VectorMachineSupport : public VectorMachineSupport { SSE42x86VectorMachineSupport() = default; virtual ~SSE42x86VectorMachineSupport() = default; - int64_t VectorRegisterNum() override { return 16; } - int64_t getVectorBitWidth() override { return 128; } - int64_t getVectorLength(GenericOps gop, mlir::Type elementType) override; + std::string computeArchName() override { return "x86-sse4.2"; } + int64_t computeArchVectorRegisterNum() override { return 16; } + int64_t computeArchVectorBitWidth() override { return 128; } + int64_t computeArchVectorLength( + GenericOps gop, mlir::Type elementType) override; }; class AVX2x86VectorMachineSupport : public SSE42x86VectorMachineSupport { @@ -170,7 +207,8 @@ class AVX2x86VectorMachineSupport : public SSE42x86VectorMachineSupport { AVX2x86VectorMachineSupport() = default; virtual ~AVX2x86VectorMachineSupport() = default; - int64_t getVectorBitWidth() override { return 258; } + std::string computeArchName() override { return "x86-avx2"; } + int64_t computeArchVectorBitWidth() override { return 258; } }; // Support for Arm 64 @@ -180,9 +218,12 @@ class NeonVectorMachineSupport : public VectorMachineSupport { NeonVectorMachineSupport() = default; virtual ~NeonVectorMachineSupport() = default; - int64_t VectorRegisterNum() override { return 32; } - int64_t getVectorBitWidth() override { return 128; } - int64_t getVectorLength(GenericOps gop, mlir::Type elementType) override; + std::string computeArchName() override { return "arm64-neon"; } + int64_t computeArchVectorRegisterNum() override { return 32; } + int64_t computeArchVectorBitWidth() override { return 128; } + int64_t computeArchVectorLength( + GenericOps gop, mlir::Type elementType) override; }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/DialectBuilder.hpp b/src/Dialect/ONNX/DialectBuilder.hpp index 5bfdfca3aa..9ff98a3755 100644 --- a/src/Dialect/ONNX/DialectBuilder.hpp +++ b/src/Dialect/ONNX/DialectBuilder.hpp @@ -4,7 +4,7 @@ //===----------- DialectBuilder.hpp - Builder for ONNX dialects -----------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_DIALECT_BUILDER_H +#define ONNX_MLIR_ONNX_DIALECT_BUILDER_H #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -327,3 +328,4 @@ struct IndexExprBuilderForAnalysis : IndexExprBuilder { #include "DialectBuilder.hpp.inc" } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ElementsAttr/BType.hpp b/src/Dialect/ONNX/ElementsAttr/BType.hpp index bb337d7455..dbe4cc9a9c 100644 --- a/src/Dialect/ONNX/ElementsAttr/BType.hpp +++ b/src/Dialect/ONNX/ElementsAttr/BType.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_B_TYPE_H +#define ONNX_MLIR_B_TYPE_H #include "src/Support/SmallFP.hpp" @@ -267,4 +268,5 @@ auto dispatchByBType(BType btype, Action &&act) { #undef ACT } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index dd1ab9e1f1..3739904c17 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DISPOSABLE_ELEMENTS_ATTR_H +#define ONNX_MLIR_DISPOSABLE_ELEMENTS_ATTR_H #include "src/Dialect/ONNX/ElementsAttr/BType.hpp" #include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp" @@ -326,3 +327,4 @@ class DisposableElementsAttr } // namespace mlir MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::DisposableElementsAttr) +#endif diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp index 9a538c06ef..6c9972853d 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp @@ -25,7 +25,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DISPOSABLE_POOL_H +#define ONNX_MLIR_DISPOSABLE_POOL_H #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" @@ -97,3 +98,4 @@ class DisposablePool : public mlir::DialectInterface::Base { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 0414aaabe6..f7276b6ebb 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H +#define ONNX_MLIR_ELEM_ATTR_BUILDER_H #include "src/Dialect/ONNX/ElementsAttr/BType.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" @@ -272,3 +273,4 @@ class ElementsAttrBuilder { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp index db16daf380..ab394ee636 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once 
+#ifndef ONNX_MLIR_ELEM_ATTR_HELPER_H +#define ONNX_MLIR_ELEM_ATTR_HELPER_H #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" #include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp" @@ -51,3 +52,4 @@ void readElementsWideNums( #include "ElementsAttrHelper.hpp.inc" } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.hpp b/src/Dialect/ONNX/ElementsAttr/Strides.hpp index 8c2e4ce7b4..1402207a7b 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.hpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.hpp @@ -34,7 +34,8 @@ // DenseElementsAttr::reshape() which always reuses its linear array. //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_STRIDES_H +#define ONNX_MLIR_STRIDES_H #include "src/Support/Arrays.hpp" @@ -109,4 +110,5 @@ void restrideArray(llvm::ArrayRef shape, castMutableArrayRef(dst)); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Dialect/ONNX/ElementsAttr/StridesRange.hpp b/src/Dialect/ONNX/ElementsAttr/StridesRange.hpp index 8816b43ced..1994a76777 100644 --- a/src/Dialect/ONNX/ElementsAttr/StridesRange.hpp +++ b/src/Dialect/ONNX/ElementsAttr/StridesRange.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_STRIDES_RANGE_H +#define ONNX_MLIR_STRIDES_RANGE_H #include "mlir/IR/BuiltinTypeInterfaces.h" #include "llvm/ADT/ArrayRef.h" @@ -162,4 +163,5 @@ class StridesRange { // Include template implementations. #include "StridesRange.hpp.inc" -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Dialect/ONNX/ElementsAttr/WideNum.hpp b/src/Dialect/ONNX/ElementsAttr/WideNum.hpp index bc3adde34a..b7bb83766a 100644 --- a/src/Dialect/ONNX/ElementsAttr/WideNum.hpp +++ b/src/Dialect/ONNX/ElementsAttr/WideNum.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_WIDE_NUM_H +#define ONNX_MLIR_WIDE_NUM_H #include "src/Dialect/ONNX/ElementsAttr/BType.hpp" @@ -223,3 +224,4 @@ inline auto wideZeroDispatch(mlir::Type type, Action &&act) { #include "WideNum.hpp.inc" } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ONNXAttributes.hpp b/src/Dialect/ONNX/ONNXAttributes.hpp index 22915e83a5..d038cc8cbe 100644 --- a/src/Dialect/ONNX/ONNXAttributes.hpp +++ b/src/Dialect/ONNX/ONNXAttributes.hpp @@ -8,9 +8,11 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_ATTRIBUTES_H +#define ONNX_MLIR_ONNX_ATTRIBUTES_H #include "mlir/IR/Attributes.h" #define GET_ATTRDEF_CLASSES #include "src/Dialect/ONNX/ONNXAttributes.hpp.inc" +#endif diff --git a/src/Dialect/ONNX/ONNXDialect.hpp b/src/Dialect/ONNX/ONNXDialect.hpp index 5b94b15058..9ad86b40b0 100644 --- a/src/Dialect/ONNX/ONNXDialect.hpp +++ b/src/Dialect/ONNX/ONNXDialect.hpp @@ -4,7 +4,7 @@ //===-------------------------- ONNXDialect.hpp ---------------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,8 +12,10 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_DIALECT_H +#define ONNX_MLIR_ONNX_DIALECT_H #include "mlir/IR/Dialect.h" #include "src/Dialect/ONNX/ONNXDialect.hpp.inc" +#endif diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.hpp b/src/Dialect/ONNX/ONNXDimAnalysis.hpp index ec94eaa6df..e7df7ccb09 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.hpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.hpp @@ -4,7 +4,7 @@ //===-------- ONNXDimAnalysis.hpp - ONNX Dimension Analysis ---------------===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_DIM_ANALYSIS_H +#define ONNX_MLIR_ONNX_DIM_ANALYSIS_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" @@ -129,3 +130,4 @@ class DimAnalysis { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ONNXLayoutHelper.hpp b/src/Dialect/ONNX/ONNXLayoutHelper.hpp index 98b8de7bc5..eeb25632b0 100644 --- a/src/Dialect/ONNX/ONNXLayoutHelper.hpp +++ b/src/Dialect/ONNX/ONNXLayoutHelper.hpp @@ -4,13 +4,14 @@ //===---------- ONNXLayoutHelper.hpp - ONNX Layout Helper -----------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_LAYOUT_HELPER_H +#define ONNX_MLIR_ONNX_LAYOUT_HELPER_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" @@ -23,3 +24,4 @@ const std::string LAYOUT_KCMN4C4K = "KCMN4C4K"; const std::string LAYOUT_STANDARD = "STANDARD"; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index f016557608..febb5207c0 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -4,7 +4,7 @@ //===-------------------- ONNXOps.hpp - ONNX Operations -------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_OPS_H +#define ONNX_MLIR_ONNX_OPS_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "src/Dialect/ONNX/ONNXAttributes.hpp" @@ -30,3 +31,4 @@ static constexpr int CURRENT_ONNX_OPSET = 20; #define GET_OP_CLASSES #include "src/Dialect/ONNX/ONNXOps.hpp.inc" +#endif diff --git a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.hpp b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.hpp index 9a335d4bf4..6689242fb8 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_EINSUM_HELPER_H +#define ONNX_MLIR_EINSUM_HELPER_H #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLFunctionalExtras.h" @@ -70,3 +71,4 @@ mlir::FailureOr inferSignature( } // namespace einsum } // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index 2c2d089162..3d827f85d5 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -2,9 +2,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===------- ONNXOpsHelper.hpp - Helper functions for ONNX dialects -------===// +//===---------- OpHelper.hpp - Helper functions for ONNX dialects ---------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_OPS_HELPER_H +#define ONNX_MLIR_OPS_HELPER_H #include "mlir/Dialect/Traits.h" #include "mlir/IR/AffineExpr.h" @@ -309,3 +310,4 @@ std::string getNodeNameInPresenceOfOpt( #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc" } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc index 8ad0944af9..b0fa82f8c4 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc @@ -4,7 +4,7 @@ //===----- ONNXOpsHelper.hpp.inc - Helper functions for ONNX dialects -----===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp index f383509561..7f27d19ebb 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp @@ -4,7 +4,7 @@ //===------------------ DynamicQuantizeLinear.cpp - ONNX Operations -------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index c3cb67a1a7..7561575f11 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -4,7 +4,7 @@ //===----------------ONNXShapeHelper.cpp - help for shapes----------------=== // // -// Copyright 2020-2023 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -517,8 +517,23 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( // no need to collapse the 1 dimensions... it brings no advantages. So by // skipping the updating of collapsedInnermostLoops here, we will omit // these leading ones. - if (numOfOnes == dimNum) { + + // Revision: it is actually good to detects 1s everywhere as we can + // collapse the loop and have less overhead. +#define REVISION_COLLAPSE_ALL_ONES 1 + bool allOnes = numOfOnes == dimNum; + if (allOnes) { +#if REVISION_COLLAPSE_ALL_ONES + // No need to update the sizes as dim is all ones. + collapsedInnermostLoops = -r; + LLVM_DEBUG(llvm::dbgs() << " SUCCESS (all ones) at collapsing " + << collapsedInnermostLoops + << " inner loops with cumulative static size of " + << collapsedLiteralSize << "\n\n"); + +#else LLVM_DEBUG(llvm::dbgs() << " all ones, done\n"); +#endif continue; } @@ -530,7 +545,8 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( LLVM_DEBUG(llvm::dbgs() << " check non-scalar compatibility\n"); // For all non scalars... for (int64_t d = 0; d < dimNum; ++d) { - // Consider only dims d that are not scalar, and skip d == nonScalarID. + // Consider only dims d that are not scalar, and skip d == + // nonScalarID. if (isOne[d] || d == nonScalarID) continue; // Compare nonScalarID with d @@ -551,8 +567,8 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( << nonScalarID << " & " << d << "; abort\n"); return collapsedInnermostLoops > 0; } - // We could not determine compatibility with literals, try deducing info - // with dim analysis, if available. + // We could not determine compatibility with literals, try deducing + // info with dim analysis, if available. if (canUseDimAnalysis && /* Use negative index convention here as operands may have fewer than outputRank dimensions */ @@ -571,7 +587,7 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( } // End for all non-scalars, } // End testing non-scalar compatibility. - // 4) Since we have at least one non-scalar, + // 4) Since we have at least one non-scalar // 4.1) all the scalar inputs are now marked as having a broadcast. // 4.2) any inputs with a one that is not a scalar has a new broadcast, // which is not allowed as only scalars can be broadcast to be @@ -585,10 +601,10 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( } else if (isOne[d]) { // Is one but is not a scalar. // Case 1x4x1, 2x4x1, and 1x1x1: no broadcast at r==-1, broadcast at // r==-2 for last entry, no broadcast for the others. At r==-3, - // continued broadcast for last entry, but first entry has new broadcast - // to size 2 (i.e. isOne[0] is true, and isScalar[0] is false). We - // cannot manage this. Abort at this rank r; thus stops at previous - // iteration of r. + // continued broadcast for last entry, but first entry has new + // broadcast to size 2 (i.e. 
isOne[0] is true, and isScalar[0] is + // false). We cannot manage this. Abort at this rank r; thus stops at + // previous iteration of r. LLVM_DEBUG(llvm::dbgs() << " one and no scalar" << d << "; abort\n"); return collapsedInnermostLoops > 0; } diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp index d10be492ff..01a8943ead 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SHAPE_HELPER_H +#define ONNX_MLIR_SHAPE_HELPER_H #include @@ -192,6 +193,9 @@ struct ONNXBroadcastOpShapeHelper : public ONNXOpShapeHelper { // has MB, and if it does, we then attempt to test at the next innermost // level... until we fail or we run out of dimensions. // + // Below: ?1 and ?2 indicate 2 dynamic dimensions, which may/may not be + // guaranteed to be equal depending on what dynamic analysis says. + // // Manageable broadcast (MB) is either that: // 1) we have no broadcast up to that level, or // 2) we have scalars up to that level being broadcasted. @@ -226,6 +230,9 @@ struct ONNXBroadcastOpShapeHelper : public ONNXOpShapeHelper { // - (1,3) and (1, 1) have MB at CIL 1; technically, CIL 2 is also a MB but // there is nothing to be gained by collapsing dimensions where all // inputs have dimensions of 1. We thus do not include them in our CILs. + // Revision: it is actually good to detects 1s everywhere as we can + // collapse the loop and have less overhead. + virtual bool hasManageableBroadcastForInnerDims( int64_t &collapsedInnermostLoops, int64_t &collapsedLiteralSize, IndexExpr &collapsedDynamicSize, DimAnalysis *dimAnalysis); @@ -943,3 +950,4 @@ void SaveOnnxAttrInOp(mlir::Operation *op, } } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/ONNXTypes.hpp b/src/Dialect/ONNX/ONNXTypes.hpp index 41667cf73b..797bf2a405 100644 --- a/src/Dialect/ONNX/ONNXTypes.hpp +++ b/src/Dialect/ONNX/ONNXTypes.hpp @@ -4,7 +4,7 @@ //===--------------------- ONNXTypes.hpp - ONNX Types ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,9 +12,11 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_TYPES_H +#define ONNX_MLIR_ONNX_TYPES_H #include "mlir/IR/Types.h" #define GET_TYPEDEF_CLASSES #include "src/Dialect/ONNX/ONNXTypes.hpp.inc" +#endif diff --git a/src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp b/src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp index b7ca6e4fdb..4e5639ec64 100644 --- a/src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ONNX_ELEMENTS_ATTR_H +#define ONNX_MLIR_ONNX_ELEMENTS_ATTR_H #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp" @@ -26,3 +27,4 @@ struct OnnxElementsAttrBuilder : ElementsAttrBuilder { }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/Transforms/ConstProp.cpp b/src/Dialect/ONNX/Transforms/ConstProp.cpp index 1d3f7f7ecc..49d4855042 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.cpp +++ b/src/Dialect/ONNX/Transforms/ConstProp.cpp @@ -19,6 +19,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" @@ -122,6 +123,16 @@ Value createReplacingConstantOp( template using EnableNotBool = std::enable_if_t>; +template +using EnableBool = std::enable_if_t>; + +template +using EnableInteger = + std::enable_if_t && !std::is_same_v>; + +template +using EnableFloatingPoint = std::enable_if_t>; + /// Checks whether a variadic value is produced by dense ONNXConstantOps. bool isVariadicOperandFromDenseONNXConstantOp(ValueRange operands) { return llvm::all_of(operands, [](Value v) { return isDenseONNXConstant(v); }); @@ -134,6 +145,47 @@ Value ConstZeroTensor( type, rewriter.getZeroAttr(type.getElementType()))); } +template , + typename GetIntConstFunc = std::function> +Value getClipConstantOfType(PatternRewriter &rewriter, ShapedType type, + Location loc, GetFPConstFunc fpConstantFunc, bool isNegative, + GetIntConstFunc intConstantFunc) { + OnnxBuilder create(rewriter, loc); + auto elemType = type.getElementType(); + if (auto floatType = dyn_cast(elemType)) { + auto fpValue = + fpConstantFunc(floatType.getFloatSemantics(), /*Negative=*/isNegative); + return create.constant(DenseElementsAttr::get( + RankedTensorType::get({}, elemType), llvm::ArrayRef(fpValue))); + } + auto intValue = intConstantFunc(elemType.getIntOrFloatBitWidth()); + return create.constant(DenseElementsAttr::get( + RankedTensorType::get({}, elemType), llvm::ArrayRef(intValue))); +} + +Value createMaximumValueForClip( + PatternRewriter &rewriter, ShapedType type, Value value) { + + // Return 'value' if exists, as there is no need to clip to largest. + if (!isNoneValue(value)) + return value; + + return getClipConstantOfType(rewriter, type, value.getLoc(), + llvm::APFloat::getLargest, false, llvm::APInt::getMaxValue); +} + +Value createMinimumValueForClip( + PatternRewriter &rewriter, ShapedType type, Value value) { + + // Return 'value' if exists, as there is no need to clip to lowest. 
+ if (!isNoneValue(value)) + return value; + + return getClipConstantOfType(rewriter, type, value.getLoc(), + llvm::APFloat::getLargest, true, llvm::APInt::getMinValue); +} + WideNum asWideNum(double n, Type elemType) { return wideZeroDispatch(elemType, [n](auto wideZero) { using cpptype = decltype(wideZero); @@ -203,7 +255,41 @@ struct ElementWiseBinaryOpImpl> { template struct ElementWiseBinaryOpImpl> { - static T eval(T lhs, T rhs) { return lhs / rhs; } + static T eval(T lhs, T rhs) { + if constexpr (std::is_integral_v) { + if (rhs == 0) { + // Undefined behavior. We can return any value. + // Performing the divison would crash. + return lhs; + } + } + return lhs / rhs; + } +}; + +template +struct ElementWiseBinaryOpImpl> { + static T eval(T lhs, T rhs) { return lhs & rhs; } +}; + +template +struct ElementWiseBinaryOpImpl> { + static T eval(T lhs, T rhs) { return lhs | rhs; } +}; + +template +struct ElementWiseBinaryOpImpl> { + static T eval(T lhs, T rhs) { return lhs && rhs; } +}; + +template +struct ElementWiseBinaryOpImpl> { + static T eval(T lhs, T rhs) { return lhs || rhs; } +}; + +template +struct ElementWiseBinaryOpImpl> { + static T eval(T lhs, T rhs) { return lhs != rhs; } }; template @@ -340,11 +426,56 @@ struct ElementWiseUnaryOpImpl { static T eval(T val) { llvm_unreachable("unsupported op or type"); } }; +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return ~val; } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return ceil(val); } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return cos(val); } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return std::erf(val); } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return std::exp(val); } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return floor(val); } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return std::log(val); } +}; + template struct ElementWiseUnaryOpImpl> { static T eval(T val) { return -val; } }; +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return !val; } +}; + +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return sin(val); } +}; + template <> struct ElementWiseUnaryOpImpl { static double eval(double val) { return sqrt(val); } diff --git a/src/Dialect/ONNX/Transforms/ConstProp.hpp b/src/Dialect/ONNX/Transforms/ConstProp.hpp index 4494c1e921..1c54be9e5b 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.hpp +++ b/src/Dialect/ONNX/Transforms/ConstProp.hpp @@ -2,7 +2,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#pragma once +#ifndef ONNX_MLIR_CONST_PROP_H +#define ONNX_MLIR_CONST_PROP_H #include "mlir/IR/PatternMatch.h" @@ -12,3 +13,4 @@ namespace onnx_mlir { void getConstPropONNXToONNXPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Dialect/ONNX/Transforms/ConstProp.td b/src/Dialect/ONNX/Transforms/ConstProp.td index 9afefef816..d712e0a1b2 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.td +++ b/src/Dialect/ONNX/Transforms/ConstProp.td @@ -111,6 +111,17 @@ class EqualString : Constraint>; // Creation helpers: +def CreateMaximumValueForClip: NativeCodeCall< + "createMaximumValueForClip($_builder, cast($0.getType()), $1)" +>; + +def CreateMinimumValueForClip: NativeCodeCall< + "createMinimumValueForClip($_builder, cast($0.getType()), $1)" +>; + +def CreateONNXMinOp : NativeCodeCall<"$_builder.create($_loc, 
$0.getType(), ValueRange({$1, $2}))">; +def CreateONNXMaxOp : NativeCodeCall<"$_builder.create($_loc, $0.getType(), ValueRange({$1, $2}))">; + def CreateZeroTensorOfType: NativeCodeCall< "ConstZeroTensor($_builder, $_loc, mlir::cast($0.getType()))" >; @@ -124,9 +135,36 @@ def CreateSubOfTwoConst : def CreateCastOfConst : NativeCodeCall<"ConstPropCast($_builder, $0, $1, $2, $3)">; +def CreateBitwiseNotOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateCeilOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateCosOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateErfOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateExpOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateFloorOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateLogOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + def CreateNegOfConst : NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; +def CreateNotOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateSinOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + def CreateSqrtOfConst : NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; @@ -166,6 +204,21 @@ def CreateLessOrEqualOfTwoConst : def CreateGreaterOrEqualOfTwoConst : NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; +def CreateBitwiseAndOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateBitwiseOrOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateAndOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateOrOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateXorOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + def CreatePowOfTwoConst : NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; @@ -351,6 +404,62 @@ def CastofConst : NamedPat<"CastofConst", (CreateCastOfConst $castOp, $input, $saturate, $to), [(IsFromDenseONNXConstantOp:$input)]>; +// Constant Propagation for BitwiseNot +def BitwiseNotConstProp : NamedPat<"BitwiseNotofConst", + // From bitwise_not(c). + (ONNXBitwiseNotOp:$bitwiseNotOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To ~c + (CreateBitwiseNotOfConst $bitwiseNotOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Ceil +def CeilConstProp : NamedPat<"CeilofConst", + // From ceil(c). + (ONNXCeilOp:$ceilOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c + (CreateCeilOfConst $ceilOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Cos +def CosConstProp : NamedPat<"CosofConst", + // From cos(c). + (ONNXCosOp:$cosOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateCosOfConst $cosOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Erf +def ErfConstProp : NamedPat<"ErfofConst", + // From erf(c) + (ONNXErfOp:$erfOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateErfOfConst $erfOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Exp +def ExpConstProp : NamedPat<"ExpofConst", + // From exp(c). 
+ (ONNXExpOp:$expOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateExpOfConst $expOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Floor +def FloorConstProp : NamedPat<"FloorofConst", + // From floor(c). + (ONNXFloorOp:$floorOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateFloorOfConst $floorOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Log +def LogConstProp : NamedPat<"LogofConst", + // From log(c) + (ONNXLogOp:$logOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateLogOfConst $logOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + // Neg of constant is simply -const def NegofConst : NamedPat<"NegofConst", // From - (c) @@ -359,6 +468,30 @@ def NegofConst : NamedPat<"NegofConst", (CreateNegOfConst $negOp, $input), [(IsFromDenseONNXConstantOp:$input)]>; +// Not of constant is simply !const +def NotofConst : NamedPat<"NotofConst", + // From c + (ONNXNotOp:$notOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To !c + (CreateNotOfConst $notOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Sin +def SinConstProp : NamedPat<"SinofConst", + // From sin(c). + (ONNXSinOp:$sinOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c + (CreateSinOfConst $sinOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + +// Constant Propagation for Reciprocal +def ReciprocalConstProp : NamedPat<"ReciprocalofConst", + // From 1/c. + (ONNXReciprocalOp:$reciprocalOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c + (CreateReciprocalOfConst $reciprocalOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + // Change a subtraction of a constant c by an addition of -c. Helpfull to combine // with other add optimizations. 
def SubConstToNeg : NamedPat<"SubConstToNeg", @@ -535,6 +668,66 @@ def DivOnesOnRhs : NamedPat<"DivOnesOnRhs", (ValuesHaveSameType $result, $x) ]>; +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXBitwiseAndOp +//===----------------------------------------------------------------------===// + +def BitwiseAndConstPropPattern : NamedPat<"BitwiseAndConstPropPattern", + (ONNXBitwiseAndOp:$result + (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), + (CreateBitwiseAndOfTwoConst $result, $lhs, $rhs), + [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), + (SatisfiesExpansionBound:$result)]>; + +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXBitwiseOrOp +//===----------------------------------------------------------------------===// + +def BitwiseOrConstPropPattern : NamedPat<"BitwiseOrConstPropPattern", + (ONNXBitwiseOrOp:$result + (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), + (CreateBitwiseOrOfTwoConst $result, $lhs, $rhs), + [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), + (SatisfiesExpansionBound:$result)]>; + +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXAndOp +//===----------------------------------------------------------------------===// + +def AndConstPropPattern : NamedPat<"AndConstPropPattern", + (ONNXAndOp:$result + (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), + (CreateAndOfTwoConst $result, $lhs, $rhs), + [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), + (SatisfiesExpansionBound:$result)]>; + +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXOrOp +//===----------------------------------------------------------------------===// + +def OrConstPropPattern : NamedPat<"OrConstPropPattern", + (ONNXOrOp:$result + (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), + (CreateOrOfTwoConst $result, $lhs, $rhs), + [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), + (SatisfiesExpansionBound:$result)]>; + +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXorOp +//===----------------------------------------------------------------------===// + +def XorConstPropPattern : NamedPat<"XorConstPropPattern", + (ONNXXorOp:$result + (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), + (CreateXorOfTwoConst $result, $lhs, $rhs), + [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), + (SatisfiesExpansionBound:$result)]>; + //===----------------------------------------------------------------------===// // Constant propagation for ONNXEqualOp //===----------------------------------------------------------------------===// @@ -611,6 +804,29 @@ def ModConstPropPattern : NamedPat<"ModConstPropPattern", [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), (SatisfiesExpansionBound:$modOp)]>; +//===----------------------------------------------------------------------===// +// Pattern for Clip. 
+// Converts clip into min(max(c0, c1), c2) as Min and Max +// already have their own foldings. +// If c1 or c2 isn't defined, default max to numeric_limits::max() and +// min to numeric_limits::lowest() values according to the element type. +//===----------------------------------------------------------------------===// + +// Constant Propagation for Clip +def ClipConstProp : NamedPat<"ClipConstProp", + // From clip(c0, c1, c2). + (ONNXClipOp:$clipOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), + $min, $max), + // To min(max(c0, c1), c2). + (CreateONNXMinOp + $clipOp, + (CreateONNXMaxOp $clipOp, $input, (CreateMinimumValueForClip $input, $min)), + (CreateMaximumValueForClip $input, $max)), + // Clip constraints + [(IsFromDenseONNXConstantOp:$input), + (IsFromDenseONNXConstantOpOrNone:$min), (IsFromDenseONNXConstantOpOrNone:$max), + (SatisfiesExpansionBound:$clipOp)]>; + //===----------------------------------------------------------------------===// // Patterns for Where. //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/Transforms/ConvOpt.hpp b/src/Dialect/ONNX/Transforms/ConvOpt.hpp index a3d7105d80..96b5fbf239 100644 --- a/src/Dialect/ONNX/Transforms/ConvOpt.hpp +++ b/src/Dialect/ONNX/Transforms/ConvOpt.hpp @@ -2,7 +2,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#pragma once +#ifndef ONNX_MLIR_CONV_OPT_H +#define ONNX_MLIR_CONV_OPT_H #include "mlir/IR/PatternMatch.h" @@ -13,3 +14,4 @@ void getConvOptONNXToONNXPatterns( bool enableSimdDataLayoutOpt, mlir::RewritePatternSet &patterns); } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/Transforms/Decompose.hpp b/src/Dialect/ONNX/Transforms/Decompose.hpp index f91b9a7ec5..b73e6ecdb4 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.hpp +++ b/src/Dialect/ONNX/Transforms/Decompose.hpp @@ -4,7 +4,7 @@ //===----------- ONNXDecompose.cpp - ONNX High Level Rewriting ------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -17,7 +17,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DECOMPOSE_H +#define ONNX_MLIR_DECOMPOSE_H #include "mlir/IR/PatternMatch.h" @@ -28,3 +29,4 @@ namespace onnx_mlir { void getDecomposeONNXToONNXPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp b/src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp index 8caa964b57..76ec382ac6 100644 --- a/src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp +++ b/src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DECOMPOSE_EINSUM_H +#define ONNX_MLIR_DECOMPOSE_EINSUM_H #include "mlir/IR/PatternMatch.h" @@ -28,3 +29,4 @@ class DecomposeEinsumPattern }; } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp index 3f23b1242a..ad88bceb42 100644 --- a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp +++ b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Interface/ShapeInferenceOpInterface.hpp" @@ -48,9 +49,16 @@ class InstrumentONNXSignaturePass InstrumentONNXSignaturePass() = default; InstrumentONNXSignaturePass(const InstrumentONNXSignaturePass &pass) : mlir::PassWrapper>() {} + OperationPass>() { + signaturePattern = pass.signaturePattern; + } + InstrumentONNXSignaturePass(const std::string pattern) { + signaturePattern = pattern; + } private: + std::string signaturePattern; + public: StringRef getArgument() const override { return "instrument-onnx-runtime-signature"; @@ -62,25 +70,31 @@ class InstrumentONNXSignaturePass } void runOnOperation() override { + onnx_mlir::EnableByRegexOption traceSpecificOpPattern( + /*emptyIsNone*/ false); + traceSpecificOpPattern.setRegexString(signaturePattern); // Iterate on the operations nested in this function. getOperation().walk([&](mlir::Operation *op) { - if (isa(op->getDialect())) { - if (!isa(op)) { - Location loc = op->getLoc(); - OpBuilder builder(op); - std::string opName = op->getName().getStringRef().str(); - std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op); - std::string fullName = opName + ", " + nodeName; - StringAttr fullNameAttr = builder.getStringAttr(fullName); - // Enqueue all input operands, and then the results. - llvm::SmallVector operAndRes(op->getOperands()); - for (Value res : op->getResults()) - operAndRes.emplace_back(res); - // Since we may use the result of an operation, we must insert the - // print operation after the operation. - builder.setInsertionPointAfter(op); - builder.create(loc, fullNameAttr, operAndRes); - } + std::string opName = op->getName().getStringRef().str(); + auto dialect = op->getDialect(); + if (isa(dialect) || isa(op)) { + // Always skip function dialects (such as function call/return), as well + // as ONNX print signature ops. + } else if (traceSpecificOpPattern.isEnabled(opName)) { + // Add signature printing op. 
+ Location loc = op->getLoc(); + OpBuilder builder(op); + std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op); + std::string fullName = opName + ", " + nodeName; + StringAttr fullNameAttr = builder.getStringAttr(fullName); + // Enqueue all input operands, and then the results. + llvm::SmallVector operAndRes(op->getOperands()); + for (Value res : op->getResults()) + operAndRes.emplace_back(res); + // Since we may use the result of an operation, we must insert the + // print operation after the operation. + builder.setInsertionPointAfter(op); + builder.create(loc, fullNameAttr, operAndRes); } }); } @@ -90,6 +104,7 @@ class InstrumentONNXSignaturePass /*! * Create an instrumentation pass. */ -std::unique_ptr onnx_mlir::createInstrumentONNXSignaturePass() { - return std::make_unique(); +std::unique_ptr onnx_mlir::createInstrumentONNXSignaturePass( + const std::string pattern) { + return std::make_unique(pattern); } diff --git a/src/Dialect/ONNX/Transforms/Recompose.hpp b/src/Dialect/ONNX/Transforms/Recompose.hpp index 2f01036484..ab78b3583b 100644 --- a/src/Dialect/ONNX/Transforms/Recompose.hpp +++ b/src/Dialect/ONNX/Transforms/Recompose.hpp @@ -4,7 +4,7 @@ //===----------- ONNXRecompose.hpp - ONNX High Level Rewriting ------------===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -17,7 +17,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RECOMPOSE_H +#define ONNX_MLIR_RECOMPOSE_H #include "mlir/IR/PatternMatch.h" @@ -28,3 +29,4 @@ namespace onnx_mlir { void getRecomposeONNXToONNXPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir +#endif diff --git a/src/Dialect/ONNX/Transforms/ShapeInference.hpp b/src/Dialect/ONNX/Transforms/ShapeInference.hpp index 5911eee5b8..d07b36fa34 100644 --- a/src/Dialect/ONNX/Transforms/ShapeInference.hpp +++ b/src/Dialect/ONNX/Transforms/ShapeInference.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SHAPE_INFERENCE_H +#define ONNX_MLIR_SHAPE_INFERENCE_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" @@ -26,3 +27,4 @@ void getShapeInferencePatterns(mlir::RewritePatternSet &set); void inferFunctionReturnShapes(mlir::func::FuncOp f); } // namespace onnx_mlir +#endif diff --git a/src/Interface/HasOnnxSubgraphOpInterface.hpp b/src/Interface/HasOnnxSubgraphOpInterface.hpp index 88a51712f4..6bd2b0dd6b 100644 --- a/src/Interface/HasOnnxSubgraphOpInterface.hpp +++ b/src/Interface/HasOnnxSubgraphOpInterface.hpp @@ -5,7 +5,7 @@ //===------------------- HasOnnxSubgraphOpInterface.hpp ------------------===// //===------------- Has Onnx Subgraph Op Interface Definition -------------===// // -// Copyright 2020 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_HAS_ONNX_INTERFACE_H +#define ONNX_MLIR_HAS_ONNX_INTERFACE_H #include #include @@ -26,3 +27,4 @@ namespace mlir { #include "src/Interface/HasOnnxSubgraphOpInterface.hpp.inc" } // end namespace mlir +#endif diff --git a/src/Interface/ResultTypeInferenceOpInterface.hpp b/src/Interface/ResultTypeInferenceOpInterface.hpp index 166e6364c3..4455f3138f 100644 --- a/src/Interface/ResultTypeInferenceOpInterface.hpp +++ b/src/Interface/ResultTypeInferenceOpInterface.hpp @@ -5,7 +5,7 @@ //===------------ ResultTypeInferenceOpInterface.hpp --------------===// //===------- Infer Data Type for Result of Op Interface Definition -------===// // -// Copyright 2020 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -14,7 +14,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_RESULT_TYPE_INFERENCE_H +#define ONNX_MLIR_RESULT_TYPE_INFERENCE_H #include #include @@ -27,3 +28,4 @@ namespace mlir { #include "src/Interface/ResultTypeInferenceOpInterface.hpp.inc" } // end namespace mlir +#endif diff --git a/src/Interface/ShapeHelperOpInterface.hpp b/src/Interface/ShapeHelperOpInterface.hpp index bbe3398da9..0e7ec891c8 100644 --- a/src/Interface/ShapeHelperOpInterface.hpp +++ b/src/Interface/ShapeHelperOpInterface.hpp @@ -4,7 +4,7 @@ //===-------- ShapeHelperOpInterface.hpp - Definition for ShapeHelper -----===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SHAPE_HELPER_INFERENCE_H +#define ONNX_MLIR_SHAPE_HELPER_INFERENCE_H #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" @@ -26,8 +27,6 @@ namespace onnx_mlir { -using DimsExpr = llvm::SmallVector; - struct ONNXOpShapeHelper { /* Constructor for shape inference. @@ -176,3 +175,4 @@ struct ONNXOpShapeHelper { /// Include the auto-generated declarations. #include "src/Interface/ShapeHelperOpInterface.hpp.inc" +#endif diff --git a/src/Interface/ShapeInferenceOpInterface.hpp b/src/Interface/ShapeInferenceOpInterface.hpp index 76bb523b47..0055472720 100644 --- a/src/Interface/ShapeInferenceOpInterface.hpp +++ b/src/Interface/ShapeInferenceOpInterface.hpp @@ -4,7 +4,7 @@ //===---- ShapeInferenceOpInterface.hpp - Definition for ShapeInference ---===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,10 +13,12 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SHAPE_INFERENCE_INTERFACE_H +#define ONNX_MLIR_SHAPE_INFERENCE_INTERFACE_H #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpDefinition.h" /// Include the auto-generated declarations. 
#include "src/Interface/ShapeInferenceOpInterface.hpp.inc" +#endif diff --git a/src/Interface/SpecializedKernelOpInterface.hpp b/src/Interface/SpecializedKernelOpInterface.hpp index f252b32605..1d8459d738 100644 --- a/src/Interface/SpecializedKernelOpInterface.hpp +++ b/src/Interface/SpecializedKernelOpInterface.hpp @@ -2,18 +2,20 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===------------------- HasOnnxSubgraphOpInterface.hpp ------------------===// -//===------------- Has Onnx Subgraph Op Interface Definition -------------===// +//===------------------- SpecializedKernelOpInterface.hpp +//------------------===// +//===------------- Specialized Kernel Op Interface Definition -------------===// // -// Copyright 2020 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // -// This file contains the declaration of the HasOnnxSubgraph Op Interface. +// This file contains the declaration of the SpecializedKernel Op Interface. // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SPECIALIZED_KERNEL_INTERFACE_H +#define ONNX_MLIR_SPECIALIZED_KERNEL_INTERFACE_H #include #include @@ -26,3 +28,4 @@ namespace mlir { #include "src/Interface/SpecializedKernelOpInterface.hpp.inc" } // end namespace mlir +#endif diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 819f2845c4..166a19217d 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PASSES_H +#define ONNX_MLIR_PASSES_H #include #include @@ -59,7 +60,8 @@ std::unique_ptr createInstrumentPass( /// Passes for instrumenting the ONNX ops to print their operand type /// signatures at runtime. -std::unique_ptr createInstrumentONNXSignaturePass(); +std::unique_ptr createInstrumentONNXSignaturePass( + const std::string pattern); /// Pass for simplifying shape-related ONNX operations. std::unique_ptr createSimplifyShapeRelatedOpsPass(); @@ -122,3 +124,4 @@ std::unique_ptr createConvertKrnlToLLVMPass(bool verifyInputTensors, std::unique_ptr createConvertONNXToTOSAPass(); } // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Runtime/ExecutionSession.hpp b/src/Runtime/ExecutionSession.hpp index b926b24b31..856a1005cc 100644 --- a/src/Runtime/ExecutionSession.hpp +++ b/src/Runtime/ExecutionSession.hpp @@ -4,7 +4,7 @@ //===--------- ExecutionSession.hpp - ExecutionSession Declaration --------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_EXECUTION_SESSION_H +#define ONNX_MLIR_EXECUTION_SESSION_H #include #include @@ -120,3 +121,4 @@ class ExecutionSession { signatureFuncType _outputSignatureFunc = nullptr; }; } // namespace onnx_mlir +#endif diff --git a/src/Runtime/OMIndexLookup.inc b/src/Runtime/OMIndexLookup.inc index 255d856fa2..9e32c810ec 100644 --- a/src/Runtime/OMIndexLookup.inc +++ b/src/Runtime/OMIndexLookup.inc @@ -34,7 +34,8 @@ static inline uint32_t hash_string(uint32_t hval, const char *str) { // Adaptation of 32-bit FNV for int64_t values. 
static inline uint32_t hash_int64(uint32_t hval, int64_t val) { char str[20]; - snprintf(str, sizeof(str), "%lld", (long long)val); + int num_chars_written = snprintf(str, sizeof(str), "%lld", (long long)val); + assert(num_chars_written >= 0 && "snprintf write error to str"); return hash_string(hval, str); } diff --git a/src/Runtime/OMInstrument.inc b/src/Runtime/OMInstrument.inc index a875b4a314..cb87bc6bc7 100644 --- a/src/Runtime/OMInstrument.inc +++ b/src/Runtime/OMInstrument.inc @@ -146,7 +146,8 @@ static void ReportMemory() { char memOutput[200]; FILE *memPipe; mypid = getpid(); - snprintf(memCommand, sizeof(memCommand), "ps -o vsz='' -p %d", mypid); + int num_chars_written = snprintf(memCommand, sizeof(memCommand), "ps -o vsz='' -p %d", mypid); + assert(num_chars_written >= 0 && "snprintf write error to memCommand"); memPipe = popen(memCommand, "r"); if (!memPipe) { fprintf(fout, ", error-failed-to-execute-ps\n"); diff --git a/src/Runtime/OMResize.inc b/src/Runtime/OMResize.inc index 9e99cb0893..9e754cb8f4 100644 --- a/src/Runtime/OMResize.inc +++ b/src/Runtime/OMResize.inc @@ -161,6 +161,7 @@ static float interpolate_1d_with_x(OMTensor *data, float scale_factor, float x, // float points[coeffs_n]; float *points = (float *)malloc(sizeof(float) * coeffs_n); + assert(points && "failed to allocate memory for points"); get_neighbor(x_ori, n, input_width, (float *)omTensorGetDataPtr(data), points, exclude_outside); @@ -184,6 +185,7 @@ static float interpolate_nd_with_x(OMTensor *data, int n, float *scale_factors, } else { int64_t input_width = omTensorGetShape(data)[0]; float *tempData = (float *)malloc(sizeof(float) * input_width); + assert(tempData && "failed to allocate memory for tempData"); int64_t tempShape[] = {input_width}; int64_t stride = 1; @@ -231,6 +233,7 @@ static void generate_coordinates( int64_t rank, int64_t *output_size, int64_t *allCoordinates) { int64_t position = 0; int64_t *currentIter = (int64_t *)malloc(sizeof(int64_t) * rank); + assert(currentIter && "failed to allocate memory for currentIter"); coordinate_step(rank, output_size, allCoordinates, 0, currentIter, &position); free(currentIter); } @@ -253,11 +256,13 @@ static void interpolate_nd_OMTensor(OMTensor *output_OMT, OMTensor *data, output_size = (int64_t *)omTensorGetDataPtr(output_size_OMT); if (scale_factor == NULL) { scale_factor = (float *)malloc(sizeof(float) * rank); + assert(scale_factor && "failed to allocate memory for scale_factor"); for (int i = 0; i < rank; i++) { scale_factor[i] = ((float)output_size[i]) / inputShape[i]; } } else { output_size = (int64_t *)malloc(sizeof(int64_t) * rank); + assert(output_size && "failed to allocate memory for output_size"); for (int i = 0; i < rank; i++) { output_size[i] = scale_factor[i] * inputShape[i]; } @@ -272,13 +277,16 @@ static void interpolate_nd_OMTensor(OMTensor *output_OMT, OMTensor *data, // int64_t allCoordinates[outputSize][rank]; int64_t *allCoordinates = (int64_t *)malloc(outputSize * rank * sizeof(int64_t)); + assert(allCoordinates && "failed to allocate memory for allCoordinates"); generate_coordinates(rank, output_size, allCoordinates); // float coeffs_buffer[coeffs_n]; // = {1.0, 0.}; float *coeffs_buffer = (float *)malloc(sizeof(float) * coeffs_n); + assert(coeffs_buffer && "failed to allocate memory for coeffs_buffer"); for (int i = 0; i < outputSize; i++) { float *Xs = (float *)malloc(sizeof(float) * rank); + assert(Xs && "failed to allocate memory for Xs"); for (int j = 0; j < rank; j++) { Xs[j] = *(allCoordinates + i * rank + j); } 
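The runtime fixes above all apply the same hardening pattern: every `snprintf` result is checked before the buffer is used, and every `malloc` result is asserted non-NULL before it is dereferenced. Below is a minimal sketch of that pattern in isolation, using a standard 32-bit FNV-1a string hash as a stand-in for the `hash_string` routine that `hash_int64` calls in OMIndexLookup.inc; the FNV constants, the 21-byte buffer, and the `_checked` name are assumptions of this sketch, not taken from the file.

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Stand-in for hash_string: 32-bit FNV-1a over a NUL-terminated string. */
static uint32_t fnv1a_string(uint32_t hval, const char *str) {
  for (const unsigned char *p = (const unsigned char *)str; *p; ++p) {
    hval ^= (uint32_t)*p; /* xor in the next byte */
    hval *= 16777619u;    /* multiply by the FNV prime */
  }
  return hval;
}

/* Same shape as hash_int64 above: format the value as decimal text,
 * assert that snprintf did not fail, then hash the resulting string. */
static uint32_t hash_int64_checked(uint32_t hval, int64_t val) {
  char str[21]; /* fits -9223372036854775808 plus the terminating NUL */
  int num_chars_written = snprintf(str, sizeof(str), "%lld", (long long)val);
  assert(num_chars_written >= 0 && "snprintf write error to str");
  return fnv1a_string(hval, str);
}
```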
diff --git a/src/Runtime/OMSort.inc b/src/Runtime/OMSort.inc index 9c8da29d05..fea3252751 100644 --- a/src/Runtime/OMSort.inc +++ b/src/Runtime/OMSort.inc @@ -54,9 +54,9 @@ // #define compareFunctionBody(typeName, direction, load, dataPtr, idx1p, idx2p) \ { \ - uint64_t idx1 = *((uint64_t *)idx1p); \ - uint64_t idx2 = *((uint64_t *)idx2p); \ - typeName *data = (typeName *)dataPtr; \ + uint64_t idx1 = *((uint64_t *)(idx1p)); \ + uint64_t idx2 = *((uint64_t *)(idx2p)); \ + typeName *data = (typeName *)(dataPtr); \ load(typeName, v1, data[idx1]); \ load(typeName, v2, data[idx2]); \ return (direction(v1, v2) || (v1 == v2 && idx1 < idx2)) ? -1 : 1; \ @@ -89,7 +89,7 @@ typedef int( #pragma GCC diagnostic ignored "-Wcast-qual" #endif -#define Load(typeName, to, from) typeName to = from +#define Load(typeName, to, from) typeName (to) = (from) // Convert f16 elements to f32 for comparison because we don't have logic to // compare f16 elements directly on all platforms. @@ -97,7 +97,7 @@ typedef int( // Or consider converting the whole tensor to f32, sort as f32, and then // convert the sorted tensor back to f16. That may be faster than // converting elements for each comparison during sorting. -#define LoadF16AsF32(typeName, to, from) float to = om_f16_to_f32(from) +#define LoadF16AsF32(typeName, to, from) float (to) = om_f16_to_f32(from) // declare ascending functions #define Ascending(lhs, rhs) ((lhs) < (rhs)) @@ -169,35 +169,35 @@ typedef struct indexStack { #define STACK_INIT(stack, stackSize) \ do { \ - assert(stackSize > 0); \ - stack.stackData = (uint64_t *)alloca(stackSize * sizeof(uint64_t)); \ - assert(stack.stackData != NULL); \ - stack.stackSize = stackSize; \ - stack.stackTop = 0; \ + assert((stackSize) > 0); \ + (stack).stackData = (uint64_t *)alloca((stackSize) * sizeof(uint64_t)); \ + assert((stack).stackData != NULL); \ + (stack).stackSize = (stackSize); \ + (stack).stackTop = 0; \ } while (0) -#define STACK_ISEMPTY(stack) (stack.stackTop == 0) +#define STACK_ISEMPTY(stack) ((stack).stackTop == 0) #define STACK_PUSH(stack, begin, end) \ do { \ - assert(stack.stackTop <= stack.stackSize - 2); \ - stack.stackData[(stack.stackTop)++] = begin; \ - stack.stackData[(stack.stackTop)++] = end; \ + assert((stack).stackTop <= (stack).stackSize - 2); \ + (stack).stackData[((stack).stackTop)++] = (begin); \ + (stack).stackData[((stack).stackTop)++] = (end); \ } while (0) #define STACK_POP(stack, begin, end) \ do { \ - assert(stack.stackTop >= 2); \ - end = stack.stackData[--(stack.stackTop)]; \ - begin = stack.stackData[--(stack.stackTop)]; \ + assert((stack).stackTop >= 2); \ + (end) = (stack).stackData[--((stack).stackTop)]; \ + (begin) = (stack).stackData[--((stack).stackTop)]; \ } while (0) #define STACK_PRINT(stack) \ do { \ - assert(stack.stackTop >= 0); \ + assert((stack).stackTop >= 0); \ fprintf(stderr, "Stack: ["); \ - for (int64_t i = 0; (i + 1) < stack.stackTop; i += 2) { \ + for (int64_t i = 0; (i + 1) < (stack).stackTop; i += 2) { \ fprintf( \ - stderr, "<%ld:%ld>, ", stack.stackData[i], stack.stackData[i + 1]); \ + stderr, "<%ld:%ld>, ", (stack).stackData[i], (stack).stackData[i + 1]);\ } \ fprintf( \ - stderr, "] (Top=%ld,Size=%ld)\n", stack.stackTop, stack.stackSize); \ + stderr, "] (Top=%ld,Size=%ld)\n", (stack).stackTop, (stack).stackSize);\ fflush(stderr); \ } while (0) diff --git a/src/Runtime/OMTensor.inc b/src/Runtime/OMTensor.inc index 19eb594b60..cdc2fd05f3 100644 --- a/src/Runtime/OMTensor.inc +++ b/src/Runtime/OMTensor.inc @@ -480,10 +480,10 @@ static void 
printData(FILE *fout, const OMTensor *tensor) { /* Helper macros to print data for 1-4D tensors */ #define LOOP_1(INDEX, IV, UB) \ fprintf(fout, "["); \ - for (int64_t IV = 0; IV < UB; ++IV) { \ + for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ - indexes[INDEX] = IV; \ + indexes[(INDEX)] = (IV); \ int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank); \ printElement(fout, dataPtr, elemOffset, dataType); \ } \ @@ -491,31 +491,31 @@ static void printData(FILE *fout, const OMTensor *tensor) { #define LOOP_2(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t IV = 0; IV < UB; ++IV) { \ + for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ - indexes[INDEX] = IV; \ - LOOP_1(INDEX + 1, __VA_ARGS__) \ + indexes[(INDEX)] = (IV); \ + LOOP_1((INDEX) + 1, __VA_ARGS__) \ } \ fprintf(fout, "]"); #define LOOP_3(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t IV = 0; IV < UB; ++IV) { \ + for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ - indexes[INDEX] = IV; \ - LOOP_2(INDEX + 1, __VA_ARGS__) \ + indexes[(INDEX)] = (IV); \ + LOOP_2((INDEX) + 1, __VA_ARGS__) \ } \ fprintf(fout, "]"); #define LOOP_4(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t IV = 0; IV < UB; ++IV) { \ + for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ - indexes[INDEX] = IV; \ - LOOP_3(INDEX + 1, __VA_ARGS__) \ + indexes[(INDEX)] = (IV); \ + LOOP_3((INDEX) + 1, __VA_ARGS__) \ } \ fprintf(fout, "]"); diff --git a/src/Runtime/OMTensorHelper.hpp b/src/Runtime/OMTensorHelper.hpp index ec922b3cbf..c41690c517 100644 --- a/src/Runtime/OMTensorHelper.hpp +++ b/src/Runtime/OMTensorHelper.hpp @@ -4,7 +4,7 @@ //===---------- OMTensorHelper.hpp - OMTensor Helper Func header ----------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -14,7 +14,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_TENSOR_HELPER_H +#define ONNX_MLIR_TENSOR_HELPER_H #include #include @@ -180,3 +181,4 @@ std::vector> omTensorComputeIndexSet(const OMTensor *omt); template bool omTensorAreTwoOmtsClose( const OMTensor *a, const OMTensor *b, float rtol = 1e-5, float atol = 1e-5); +#endif \ No newline at end of file diff --git a/src/Runtime/OMTensorListHelper.hpp b/src/Runtime/OMTensorListHelper.hpp index f59d0362ec..b8a059ff8c 100644 --- a/src/Runtime/OMTensorListHelper.hpp +++ b/src/Runtime/OMTensorListHelper.hpp @@ -4,7 +4,7 @@ //===----- OMTensorListHelper.hpp - OMTensor List Helper Func header ------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_TENSOR_LIST_HELPER_H +#define ONNX_MLIR_TENSOR_LIST_HELPER_H #include "OnnxMlirRuntime.h" @@ -23,3 +24,4 @@ * mechanism. 
*/ void omTensorListDestroyShallow(OMTensorList *list); +#endif diff --git a/src/Runtime/jni/jnilog.c b/src/Runtime/jni/jnilog.c index 9abbc856ba..0602f2bb77 100644 --- a/src/Runtime/jni/jnilog.c +++ b/src/Runtime/jni/jnilog.c @@ -40,7 +40,7 @@ static pthread_key_t log_inited; static pthread_key_t log_level; static pthread_key_t log_fp; -#define THREAD_LOCAL_INIT(key, func) pthread_once(key, func) +#define THREAD_LOCAL_INIT(key, func) pthread_once((key), (func)) INLINE void key_init() { pthread_key_create(&log_inited, NULL); @@ -160,11 +160,14 @@ void log_printf(int level, const char *file, const char *func, int line, time_t now; struct tm *tm; char buf[LOG_MAX_LEN]; + int num_chars_written = 0; /* Get local time and format as 2020-07-03 05:17:42 -0400 */ if (time(&now) == -1 || (tm = localtime(&now)) == NULL || - strftime(buf, sizeof(buf), "[%F %T %z]", tm) == 0) - sprintf(buf, "[-]"); + strftime(buf, sizeof(buf), "[%F %T %z]", tm) == 0) { + num_chars_written = sprintf(buf, "[-]"); + assert(num_chars_written >= 0 && "sprintf write error to buf"); + } /* Output thread ID, log level, file name, function number, and line number. * Note that pthread_t on most platforms is unsigned long but is a struct @@ -172,9 +175,10 @@ void log_printf(int level, const char *file, const char *func, int line, */ pthread_t tid = get_threadid(); assert(LOG_MAX_LEN >= strlen(buf) && "error in snprintf length"); - snprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), + num_chars_written = snprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), "[%016lx][%s]%s:%s:%d ", *(unsigned long *)&tid, log_level_name[level], get_filename(file), func, line); + assert(num_chars_written >= 0 && "snprintf write error to buf"); /* Output actual log data */ /* Definition of vsnprintf: @@ -203,11 +207,15 @@ void log_printf(int level, const char *file, const char *func, int line, va_list log_data; va_start(log_data, fmt); - vsnprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), fmt, log_data); + num_chars_written = + vsnprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), fmt, log_data); + assert(num_chars_written >= 0 && "vsnprintf write error to buf"); va_end(log_data); /* Add new line */ - snprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), "\n"); + num_chars_written = + snprintf(buf + strlen(buf), LOG_MAX_LEN - strlen(buf), "\n"); + assert(num_chars_written >= 0 && "snprintf write error to buf"); /* Write out and flush the output buffer */ FILE *fp = get_log_fp(); @@ -238,9 +246,11 @@ static FILE *get_log_file_by_name(char *name) { char *tname = (char *)malloc(strlen(name) + 32); if (tname) { pthread_t tid = get_threadid(); - snprintf( + int num_chars_written = snprintf( tname, strlen(name) + 32, "%s.%016lx", name, *(unsigned long *)&tid); + assert(num_chars_written >= 0 && "snprintf write error to tname"); fp = fopen(tname, "w"); + assert(fp != NULL && "fopen error on tname"); free(tname); } } diff --git a/src/Runtime/jni/jnilog.h b/src/Runtime/jni/jnilog.h index 9968bf34b3..68e79299e9 100644 --- a/src/Runtime/jni/jnilog.h +++ b/src/Runtime/jni/jnilog.h @@ -23,24 +23,25 @@ enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL }; #define LOG_MAX_LEN 4096 /* max number of chars to output */ #define LOG_MAX_NUM 128 /* max number of elements to output */ -#define MIN(x, y) ((x) > (y) ? y : x) +#define MIN(x, y) ((x) > (y) ? (y) : (x)) /* Construct string of up to LOG_MAX_NUM elements of an array of C type * To avoid variable name clash, prefix with double underscores. 
*/ #define LOG_BUF_C_TYPE(type, format, buf, data, n) \ do { \ - char *__p = buf; \ + char *__p = (buf); \ /* Reserve 5 char at the end for " ... \0". Note the first \ * space will come from the '\0' of the previous string. \ */ \ - int __i = 0, __j = sizeof(buf) - 5, __k, __l = MIN(n, LOG_MAX_NUM); \ + int __i = 0, __j = sizeof(buf) - 5, __k, __l = MIN((n), LOG_MAX_NUM); \ /* j is the available number of chars including '\0'. k is the \ * number of chars printed without '\0'. So as long as k < j, \ * it means the output, with a trailing '\0', fits in the buffer. \ */ \ - while (__i < __l && \ - (__k = snprintf(__p, __j, format, ((type *)data)[__i])) < __j) { \ + while (__i < __l && (__k = snprintf(__p, __j, (format), \ + ((type *)(data))[__i])) < __j) { \ + assert(__k >= 0 && "snprintf write error to __p"); \ __p += __k; \ __j -= __k; \ __i++; \ @@ -55,8 +56,9 @@ enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL }; * add "... " at the end to denote that the last element is \ * truncated. \ */ \ - snprintf(buf + strlen(buf), 6, \ + int __m = snprintf(buf + strlen(buf), 6, \ (__i == __l) ? " " : (__j == 1) ? " ... " : "... "); \ + assert(__m >= 0 && "snprintf write error to buf"); \ } while (0) /* Construct string of up to LOG_MAX_NUM elements of an array of ONNX type. @@ -67,64 +69,73 @@ enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL }; switch (type) { \ case ONNX_TYPE_UINT8: \ case ONNX_TYPE_INT8: \ - LOG_BUF_C_TYPE(const char, hex ? " %02x" : "%c", buf, data, n); \ + LOG_BUF_C_TYPE(const char, (hex) ? " %02x" : "%c", (buf), (data), (n)); \ break; \ case ONNX_TYPE_UINT16: \ case ONNX_TYPE_INT16: \ - LOG_BUF_C_TYPE(const short, hex ? " %04x" : " %d", buf, data, n); \ + LOG_BUF_C_TYPE( \ + const short, (hex) ? " %04x" : " %d", (buf), (data), (n)); \ break; \ case ONNX_TYPE_UINT32: \ case ONNX_TYPE_INT32: \ - LOG_BUF_C_TYPE(const int, hex ? " %08x" : " %d", buf, data, n); \ + LOG_BUF_C_TYPE(const int, (hex) ? " %08x" : " %d", (buf), (data), (n)); \ break; \ case ONNX_TYPE_UINT64: \ case ONNX_TYPE_INT64: \ - LOG_BUF_C_TYPE(const long, hex ? " %016x" : " %ld", buf, data, n); \ + LOG_BUF_C_TYPE( \ + const long, (hex) ? " %016x" : " %ld", (buf), (data), (n)); \ break; \ case ONNX_TYPE_FLOAT16: \ - LOG_BUF_C_TYPE(const short, " %04x", buf, data, n); \ + LOG_BUF_C_TYPE(const short, " %04x", (buf), (data), (n)); \ break; \ case ONNX_TYPE_FLOAT: \ - LOG_BUF_C_TYPE(const float, hex ? " %08x" : " %f", buf, data, n); \ + LOG_BUF_C_TYPE( \ + const float, (hex) ? " %08x" : " %f", (buf), (data), (n)); \ break; \ case ONNX_TYPE_DOUBLE: \ - LOG_BUF_C_TYPE(const double, hex ? " %016x" : " %lf", buf, data, n); \ + LOG_BUF_C_TYPE( \ + const double, (hex) ? 
" %016x" : " %lf", (buf), (data), (n)); \ break; \ - default: \ - sprintf(buf, " unsupported data type %d ", type); \ + default: { \ + int __a = sprintf((buf), " unsupported data type %d ", (type)); \ + assert(__a >= 0 && "sprintf write error to buf"); \ + } \ } \ } while (0) -#define LOG_BUF(type, buf, data, n) LOG_BUF_ONNX_TYPE(type, buf, data, n, 0) -#define LOG_XBUF(type, buf, data, n) LOG_BUF_ONNX_TYPE(type, buf, data, n, 1) +#define LOG_BUF(type, buf, data, n) \ + LOG_BUF_ONNX_TYPE((type), (buf), (data), (n), 0) +#define LOG_XBUF(type, buf, data, n) \ + LOG_BUF_ONNX_TYPE((type), (buf), (data), (n), 1) #define LOG_CHAR_BUF(buf, data, n) \ - LOG_BUF_C_TYPE(const char, "%c", buf, data, n) + LOG_BUF_C_TYPE(const char, "%c", (buf), (data), (n)) #define LOG_CHAR_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const char, " %02x", buf, data, n) + LOG_BUF_C_TYPE(const char, " %02x", (buf), (data), (n)) #define LOG_SHORT_BUF(buf, data, n) \ - LOG_BUF_C_TYPE(const short, " %d", buf, data, n) + LOG_BUF_C_TYPE(const short, " %d", (buf), (data), (n)) #define LOG_SHORT_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const short, " %04x", buf, data, n) -#define LOG_INT_BUF(buf, data, n) LOG_BUF_C_TYPE(const int, " %d", buf, data, n) + LOG_BUF_C_TYPE(const short, " %04x", (buf), (data), (n)) +#define LOG_INT_BUF(buf, data, n) \ + LOG_BUF_C_TYPE(const int, " %d", (buf), (data), (n)) #define LOG_INT_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const int, " %08x", buf, data, n) + LOG_BUF_C_TYPE(const int, " %08x", (buf), (data), (n)) #define LOG_LONG_BUF(buf, data, n) \ - LOG_BUF_C_TYPE(const long, " %ld", buf, data, n) + LOG_BUF_C_TYPE(const long, " %ld", (buf), (data), (n)) #define LOG_LONG_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const long, " %016x", buf, data, n) + LOG_BUF_C_TYPE(const long, " %016x", (buf), (data), (n)) #define LOG_FLOAT_BUF(buf, data, n) \ - LOG_BUF_C_TYPE(const float, " %f", buf, data, n) + LOG_BUF_C_TYPE(const float, " %f", (buf), (data), (n)) #define LOG_FLOAT_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const float, " %08x", buf, data, n) + LOG_BUF_C_TYPE(const float, " %08x", (buf), (data), (n)) #define LOG_DOUBLE_BUF(buf, data, n) \ - LOG_BUF_C_TYPE(const double, " %lf", buf, data, n) + LOG_BUF_C_TYPE(const double, " %lf", (buf), (data), (n)) #define LOG_DOUBLE_XBUF(buf, data, n) \ - LOG_BUF_C_TYPE(const double, " %016x", buf, data, n) + LOG_BUF_C_TYPE(const double, " %016x", (buf), (data), (n)) /* Main macro for log output */ #define LOG_PRINTF(level, ...) \ - log_printf(level, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__) + log_printf((level), __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__) /* Generic log routine */ extern void log_init(void); diff --git a/src/Runtime/jni/jniwrapper.c b/src/Runtime/jni/jniwrapper.c index a9e6c47838..77213a51d8 100644 --- a/src/Runtime/jni/jniwrapper.c +++ b/src/Runtime/jni/jniwrapper.c @@ -37,7 +37,7 @@ extern OMTensorList *run_main_graph(OMTensorList *); * this call simply returns NULL. */ #define CHECK_CALL(type, var, call, success, ...) 
\ - type var = call; \ + type(var) = (call); \ do { \ if (!(success)) { \ LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \ @@ -63,7 +63,7 @@ extern OMTensorList *run_main_graph(OMTensorList *); } else if (!(success)) { \ LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \ if (ecpt) \ - (*env)->ThrowNew(env, ecpt, jnistr[MSG_JNI_CALL_ERROR]); \ + (*env)->ThrowNew((env), (ecpt), jnistr[MSG_JNI_CALL_ERROR]); \ return NULL; \ } \ } while (0) @@ -72,13 +72,13 @@ extern OMTensorList *run_main_graph(OMTensorList *); * log error and throw Java exception if the call failed. */ #define JNI_VAR_CALL(env, var, call, success, ecpt, ...) \ - JNI_CALL(env, var = call, success, ecpt, __VA_ARGS__) + JNI_CALL((env), (var) = (call), (success), (ecpt), __VA_ARGS__) /* Declare type var, make a JNI call and assign return value to var, * log error and throw Java exception if the call failed. */ #define JNI_TYPE_VAR_CALL(env, type, var, call, success, ecpt, ...) \ - JNI_CALL(env, type var = call, success, ecpt, __VA_ARGS__); + JNI_CALL((env), type(var) = (call), (success), (ecpt), __VA_ARGS__); /* Make a native library call, check success condition, * log error and throw Java exception if native code failed. @@ -91,7 +91,7 @@ extern OMTensorList *run_main_graph(OMTensorList *); if (!(success)) { \ LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \ if (ecpt) \ - (*env)->ThrowNew(env, ecpt, jnistr[MSG_NATIVE_CODE_ERROR]); \ + (*env)->ThrowNew((env), (ecpt), jnistr[MSG_NATIVE_CODE_ERROR]); \ return NULL; \ } \ } while (0) @@ -101,40 +101,40 @@ extern OMTensorList *run_main_graph(OMTensorList *); * Also check success condition. */ #define LIB_VAR_CALL(var, call, success, env, ecpt, ...) \ - LIB_CALL(var = call, success, env, ecpt, __VA_ARGS__); + LIB_CALL((var) = (call), (success), (env), (ecpt), __VA_ARGS__); /* Declare type var, make a native library call and assign * return value to var, log error and throw Java exception * if the call failed. Also check success condition. */ #define LIB_TYPE_VAR_CALL(type, var, call, success, env, ecpt, ...) 
\ - LIB_CALL(type var = call, success, env, ecpt, __VA_ARGS__); + LIB_CALL(type(var) = (call), (success), (env), (ecpt), __VA_ARGS__); /* Debug output of OMTensor fields */ #define OMT_DEBUG( \ i, n, data, shape, strides, dataType, bufferSize, rank, owning) \ do { \ char tmp[1024]; \ - LOG_BUF(dataType, tmp, data, n); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:data=[%s]", i, tmp); \ + LOG_BUF((dataType), (tmp), (data), (n)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(data)=[%s]", (i), (tmp)); \ tmp[0] = '\0'; \ - LOG_LONG_BUF(tmp, shape, rank); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:shape=[%s]", i, tmp); \ - LOG_LONG_BUF(tmp, strides, rank); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:strides=[%s]", i, tmp); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataType=%d", i, dataType); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:bufferSize=%ld", i, bufferSize); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:rank=%ld", i, rank); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:owning=%ld", i, owning); \ - LOG_PRINTF(LOG_DEBUG, "omt[%d]:numElems=%ld", i, n); \ + LOG_LONG_BUF(tmp, (shape), (rank)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(shape)=[%s]", (i), tmp); \ + LOG_LONG_BUF(tmp, (strides), (rank)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(strides)=[%s]", (i), tmp); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(dataType)=%d", (i), (dataType)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(bufferSize)=%ld", (i), (bufferSize)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(rank)=%ld", (i), (rank)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(owning)=%ld", (i), (owning)); \ + LOG_PRINTF(LOG_DEBUG, "omt[%d]:(numElems)=%ld", (i), (n)); \ } while (0) /* Debug output of hex string */ #define HEX_DEBUG(label, string, n) \ do { \ char tmp[1024]; \ - LOG_CHAR_XBUF(tmp, string, n); \ - LOG_PRINTF(LOG_DEBUG, "%s(%d):[%s]", label, n, tmp); \ + LOG_CHAR_XBUF(tmp, (string), (n)); \ + LOG_PRINTF(LOG_DEBUG, "%s(%d):[%s]", (label), (n), tmp); \ } while (0) /* Java classes and methods needed for making various JNI API calls */ @@ -759,7 +759,8 @@ Java_com_ibm_onnxmlir_OMModel_query_1entry_1points(JNIEnv *env, jclass cls) { */ for (int64_t i = 0; i < neps; i++) { char ep[32]; - sprintf(ep, "ep[%lld]", (long long)i); + int num_chars_written = sprintf(ep, "ep[%lld]", (long long)i); + assert(num_chars_written >= 0 && "sprintf write error to ep"); HEX_DEBUG(ep, jni_eps[i], strlen(jni_eps[i])); LOG_PRINTF(LOG_DEBUG, "ep[%d](%ld):%s", i, strlen(jni_eps[i]), jni_eps[i]); diff --git a/src/Runtime/python/PyExecutionSession.hpp b/src/Runtime/python/PyExecutionSession.hpp index 6be915e95e..1114c0d8fe 100644 --- a/src/Runtime/python/PyExecutionSession.hpp +++ b/src/Runtime/python/PyExecutionSession.hpp @@ -4,7 +4,7 @@ //===------ PyExecutionSession.hpp - PyExecutionSession Declaration -------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PY_EXECUTION_SESSION_H +#define ONNX_MLIR_PY_EXECUTION_SESSION_H #include "PyExecutionSessionBase.hpp" @@ -40,3 +41,4 @@ PYBIND11_MODULE(PyRuntimeC, m) { .def("output_signature", &onnx_mlir::PyExecutionSession::pyOutputSignature); } +#endif diff --git a/src/Runtime/python/PyExecutionSessionBase.hpp b/src/Runtime/python/PyExecutionSessionBase.hpp index fd01c1ad71..44c2a9f1aa 100644 --- a/src/Runtime/python/PyExecutionSessionBase.hpp +++ b/src/Runtime/python/PyExecutionSessionBase.hpp @@ -4,7 +4,7 @@ //===-- PyExecutionSessionBase.hpp - PyExecutionSessionBase Declaration ---===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PY_EXECUTION_SESSION_BASE_H +#define ONNX_MLIR_PY_EXECUTION_SESSION_BASE_H #include #include @@ -53,3 +54,4 @@ class PyExecutionSessionBase std::string reportPythonError(std::string errorStr) const; }; } // namespace onnx_mlir +#endif diff --git a/src/Runtime/python/PyOMCompileExecutionSession.hpp b/src/Runtime/python/PyOMCompileExecutionSession.hpp index ac1291d003..b16bd3892e 100644 --- a/src/Runtime/python/PyOMCompileExecutionSession.hpp +++ b/src/Runtime/python/PyOMCompileExecutionSession.hpp @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PY_OM_COMPILE_SESSION_BASE_H +#define ONNX_MLIR_PY_OM_COMPILE_SESSION_BASE_H #include #include @@ -74,3 +75,4 @@ PYBIND11_MODULE(PyCompileAndRuntimeC, m) { .def("output_signature", &onnx_mlir::PyOMCompileExecutionSession::pyOutputSignature); } +#endif diff --git a/src/Support/Arrays.hpp b/src/Support/Arrays.hpp index 955004e275..67aa50fdde 100644 --- a/src/Support/Arrays.hpp +++ b/src/Support/Arrays.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_ARRAYS_H +#define ONNX_MLIR_ARRAYS_H #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" @@ -69,4 +70,5 @@ llvm::MutableArrayRef castMutableArrayRef(llvm::MutableArrayRef a) { (a.size() * sizeof(Old)) / sizeof(New)); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir +#endif \ No newline at end of file diff --git a/src/Support/Common.hpp b/src/Support/Common.hpp index 9e91508cd0..5dab92ae6a 100644 --- a/src/Support/Common.hpp +++ b/src/Support/Common.hpp @@ -4,7 +4,7 @@ //====--------------- Common.hpp - Common Utilities -----------------------===// // -// Copyright 2021 The IBM Research Authors. +// Copyright 2021-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -12,10 +12,12 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_COMMON_H +#define ONNX_MLIR_COMMON_H #if defined(__GNUC__) || defined(__clang__) #define ATTRIBUTE(x) __attribute__((x)) #else #define ATTRIBUTE(x) #endif +#endif diff --git a/src/Support/Diagnostic.hpp b/src/Support/Diagnostic.hpp index a91ce48e1f..628abfe649 100644 --- a/src/Support/Diagnostic.hpp +++ b/src/Support/Diagnostic.hpp @@ -4,7 +4,7 @@ //====--------------- Diagnostic.hpp - Diagnostic Utilities ---------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_DIAGNOSTIC_H +#define ONNX_MLIR_DIAGNOSTIC_H #include "mlir/IR/Operation.h" #include "mlir/Support/LogicalResult.h" @@ -72,3 +73,4 @@ class Diagnostic { }; } // namespace onnx_mlir +#endif diff --git a/src/Support/KrnlSupport.hpp b/src/Support/KrnlSupport.hpp index 06cb94a761..e3f4ea1533 100644 --- a/src/Support/KrnlSupport.hpp +++ b/src/Support/KrnlSupport.hpp @@ -4,7 +4,7 @@ //====---------- KrnlSupport.hpp - Krnl-level support functions -----------===// // -// Copyright 2020 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_KRNL_SUPPORT_H +#define ONNX_MLIR_KRNL_SUPPORT_H #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -94,3 +95,4 @@ int64_t getAllocArgIndex(mlir::memref::AllocOp allocOp, int64_t index); int64_t getAllocAlignment(mlir::memref::AllocOp allocOp); } // namespace onnx_mlir +#endif diff --git a/src/Support/SmallFP.hpp b/src/Support/SmallFP.hpp index 175467f4fe..491ed0a30d 100644 --- a/src/Support/SmallFP.hpp +++ b/src/Support/SmallFP.hpp @@ -8,7 +8,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SMALL_FP_H +#define ONNX_MLIR_SMALL_FP_H #include "mlir/IR/BuiltinAttributes.h" #include "src/Support/SmallFPConversion.h" @@ -253,3 +254,4 @@ struct mlir::DenseElementsAttr::is_valid_cpp_fp_type< onnx_mlir::float_8e5m2fnuz> { static constexpr bool value = true; }; +#endif diff --git a/src/Support/SmallFPConversion.c b/src/Support/SmallFPConversion.c index 24a3fbec3e..a1d6f12b1e 100644 --- a/src/Support/SmallFPConversion.c +++ b/src/Support/SmallFPConversion.c @@ -25,9 +25,9 @@ // // might violate the rules about strict aliasing in C++. #define BIT_CAST(TO_TYPE, TO, FROM) \ - TO_TYPE TO; \ + TO_TYPE(TO); \ static_assert(sizeof(TO) == sizeof(FROM), "only bit cast same sizes"); \ - memcpy(&TO, &FROM, sizeof(FROM)) + memcpy(&(TO), &(FROM), sizeof(FROM)) #if defined(__x86_64__) && defined(__F16C__) // On x86-64 build config -DCMAKE_CXX_FLAGS=-march=native defines __F16C__. 
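The BIT_CAST hunk above combines the two small cleanups that recur throughout this patch: macro parameters are wrapped in parentheses so that expansion stays well-formed for compound argument expressions, and the reinterpretation of a float's bits goes through memcpy, which (as the surrounding comment notes) avoids the strict-aliasing pitfalls of pointer casts. The following is a minimal, self-contained sketch of that same pattern; the macro and variable names are illustrative only and are not the project's own header.

```c++
// Standalone illustration of the BIT_CAST idiom (illustrative names only).
// memcpy is the portable way to reinterpret the bits of a float as an
// integer without running afoul of strict-aliasing rules.
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstring>

#define BIT_CAST_DEMO(TO_TYPE, TO, FROM)                                      \
  TO_TYPE(TO);                                                                 \
  static_assert(sizeof(TO) == sizeof(FROM), "only bit cast same sizes");       \
  memcpy(&(TO), &(FROM), sizeof(FROM))

int main() {
  float f = 1.0f;
  // The parenthesized parameters are general macro hygiene: argument
  // expressions keep their intended grouping after expansion.
  BIT_CAST_DEMO(uint32_t, bits, f);
  std::printf("0x%08" PRIx32 "\n", bits); // 1.0f prints 0x3f800000
  return 0;
}
```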
diff --git a/src/Support/SmallVectorHelper.hpp b/src/Support/SmallVectorHelper.hpp new file mode 100644 index 0000000000..a20047c578 --- /dev/null +++ b/src/Support/SmallVectorHelper.hpp @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----- SmallVectorHelper.hpp - Helper functions llvm::SmallVector -----===// +// +// Copyright 2019-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains helper functions for taking subsets of llvm::SmallVector. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVector.h" + +//===----------------------------------------------------------------------===// +// Select the first few elements of a vector, until "untilNum" (inclusively) +// Negative numbers count from the back of the vector. + +// Note: because it is inclusively, it is impossible to have an empty list. + +template +llvm::SmallVector firstFew(mlir::ValueRange vec, int64_t untilNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (untilNum < 0) + untilNum += size; + // If untilNum<0... we get an empty vector, that is ok. + assert(untilNum < size && "out of bound"); + for (int64_t i = 0; i <= untilNum; ++i) + res.emplace_back(vec[i]); + return res; +} + +template +llvm::SmallVector firstFew(mlir::ArrayRef vec, int64_t untilNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (untilNum < 0) + untilNum += size; + // If untilNum<0... we get an empty vector, that is ok. + assert(untilNum < size && "out of bound"); + for (int64_t i = 0; i <= untilNum; ++i) + res.emplace_back(vec[i]); + return res; +} + +template +llvm::SmallVector firstFew( + llvm::SmallVectorImpl &vec, int64_t untilNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (untilNum < 0) + untilNum += size; + // If untilNum<0... we get an empty vector, that is ok. + assert(untilNum < size && "out of bound"); + for (int64_t i = 0; i <= untilNum; ++i) + res.emplace_back(vec[i]); + return res; +} + +//===----------------------------------------------------------------------===// +// Select the last few elements of a vector, from "untilNum" (inclusively) +// Negative numbers count from the back of the vector. + +template +llvm::SmallVector lastFew(mlir::ValueRange vec, int64_t fromNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (fromNum < 0) + fromNum += size; + // If fromNum>= size... we get an empty vector, that is ok. + assert(fromNum >= 0 && "out of bound"); + for (int64_t i = fromNum; i < size; ++i) + res.emplace_back(vec[i]); + return res; +} + +template +llvm::SmallVector lastFew(mlir::ArrayRef vec, int64_t fromNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (fromNum < 0) + fromNum += size; + // If fromNum>= size... we get an empty vector, that is ok. + assert(fromNum >= 0 && "out of bound"); + for (int64_t i = fromNum; i < size; ++i) + res.emplace_back(vec[i]); + return res; +} + +template +llvm::SmallVector lastFew( + llvm::SmallVectorImpl &vec, int64_t fromNum) { + llvm::SmallVector res; + int64_t size = vec.size(); + if (fromNum < 0) + fromNum += size; + // If fromNum>= size... we get an empty vector, that is ok. 
+ assert(fromNum >= 0 && "out of bound"); + for (int64_t i = fromNum; i < size; ++i) + res.emplace_back(vec[i]); + return res; +} diff --git a/src/Support/SuppressWarnings.h b/src/Support/SuppressWarnings.h index a4a39b7b27..b44b485bb3 100644 --- a/src/Support/SuppressWarnings.h +++ b/src/Support/SuppressWarnings.h @@ -4,7 +4,7 @@ //====--------------- SuppressWarnings.h - Suppress Warnings --------------===// // -// Copyright 2021 The IBM Research Authors. +// Copyright 2021-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_SUPPRESS_WARNINGS_H +#define ONNX_MLIR_SUPPRESS_WARNINGS_H // clang-format off #if defined(SUPPRESS_THIRD_PARTY_WARNINGS) @@ -43,3 +44,4 @@ #define SUPPRESS_WARNINGS_POP #endif // clang-format on +#endif diff --git a/src/Support/TypeUtilities.hpp b/src/Support/TypeUtilities.hpp index 378e7ed376..951e7a0378 100644 --- a/src/Support/TypeUtilities.hpp +++ b/src/Support/TypeUtilities.hpp @@ -4,7 +4,7 @@ //====---------- TypeUtilities.hpp - functions related to MLIR Type -------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_TYPE_UTILITIES_H +#define ONNX_MLIR_TYPE_UTILITIES_H #include "mlir/IR/BuiltinTypes.h" @@ -40,3 +41,4 @@ bool sameEncodingAttr(mlir::Type t1, mlir::Type t2); unsigned getIntOrFloatByteWidth(mlir::Type ty); } // namespace onnx_mlir +#endif diff --git a/src/Tools/binary-decoder/BinaryDecoder.cpp b/src/Tools/binary-decoder/BinaryDecoder.cpp index 030376b58f..6cc6eb39c8 100644 --- a/src/Tools/binary-decoder/BinaryDecoder.cpp +++ b/src/Tools/binary-decoder/BinaryDecoder.cpp @@ -98,7 +98,7 @@ int main(int argc, char **argv) { llvm::sys::fs::remove(Filename); #define PRINT_BUFFER_FOR_TYPE(ONNX_TYPE, CPP_TYPE) \ - if (DataType == ONNX_TYPE) \ + if (DataType == (ONNX_TYPE)) \ return printBuffer(buffer); PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::BOOL, bool); diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 24d2631e47..b8285fcb69 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -71,7 +71,7 @@ void registerOMPasses(int optLevel) { []() -> std::unique_ptr { return createInstrumentPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return createInstrumentONNXSignaturePass(); + return createInstrumentONNXSignaturePass("NONE"); }); mlir::registerPass([]() -> std::unique_ptr { diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.hpp b/src/Tools/onnx-mlir-opt/RegisterPasses.hpp index 640c620273..1491532f51 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.hpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.hpp @@ -4,11 +4,12 @@ //===------------------------- RegisterPasses.hpp -------------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_REGISTER_PASSES_H +#define ONNX_MLIR_REGISTER_PASSES_H namespace onnx_mlir { @@ -16,3 +17,4 @@ namespace onnx_mlir { void registerPasses(int optLevel); } // namespace onnx_mlir +#endif diff --git a/src/Transform/ProcessScfParallelPrivate.hpp b/src/Transform/ProcessScfParallelPrivate.hpp index 9b7c52457e..fe6428c92c 100644 --- a/src/Transform/ProcessScfParallelPrivate.hpp +++ b/src/Transform/ProcessScfParallelPrivate.hpp @@ -4,7 +4,7 @@ //===- ProcessAffineParallelPrivate.hpp - Handle parallel private data ----===// // -// Copyright 2023 The IBM Research Authors. +// Copyright 2023-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,7 +13,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_PROCESS_AFFINE_PARALLEL_PRIVATE_H +#define ONNX_MLIR_PROCESS_AFFINE_PARALLEL_PRIVATE_H #include "mlir/IR/PatternMatch.h" @@ -24,3 +25,4 @@ namespace onnx_mlir { void getParallelPrivateScfToScfPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir +#endif diff --git a/src/Version/Version.hpp b/src/Version/Version.hpp index c3ee74b40b..8b85bef7fd 100644 --- a/src/Version/Version.hpp +++ b/src/Version/Version.hpp @@ -4,7 +4,7 @@ //===-------------------------- Version.hpp -------------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,7 +12,8 @@ // //===----------------------------------------------------------------------===// -#pragma once +#ifndef ONNX_MLIR_VERSION_H +#define ONNX_MLIR_VERSION_H #include "llvm/Support/raw_ostream.h" #include @@ -66,3 +67,4 @@ std::string getOnnxMlirCommitVersion(); /// given on the command line. void getVersionPrinter(llvm::raw_ostream &os); } // namespace onnx_mlir +#endif diff --git a/src/onnx-mlir.cpp b/src/onnx-mlir.cpp index 5df9da6b9d..be1d40554e 100644 --- a/src/onnx-mlir.cpp +++ b/src/onnx-mlir.cpp @@ -11,8 +11,10 @@ // Implements main for onnx-mlir driver. //===----------------------------------------------------------------------===// +#include #include +#include "mlir/IR/AsmState.h" #include "mlir/Support/Timing.h" #include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerUtils.hpp" @@ -28,7 +30,6 @@ int main(int argc, char *argv[]) { mlir::registerAsmPrinterCLOptions(); mlir::registerMLIRContextCLOptions(); mlir::registerPassManagerCLOptions(); - mlir::registerAsmPrinterCLOptions(); llvm::cl::SetVersionPrinter(getVersionPrinter); @@ -74,7 +75,12 @@ int main(int argc, char *argv[]) { } loadDialects(context); setupTiming.stop(); - std::string msg = "Importing ONNX Model to MLIR Module"; + // Add the short inputFilename to the first compile phase printout so that we + // may better determine which compilation we are dealing with. 
+ std::filesystem::path p(inputFilename); + std::string modelShortName = p.filename(); + std::string msg = + "Importing ONNX Model to MLIR Module from \"" + modelShortName + "\""; showCompilePhase(msg); auto inputFileTiming = rootTimingScope.nest("[onnx-mlir] " + msg); mlir::OwningOpRef module; diff --git a/test/accelerators/NNPA/backend/CMakeLists.txt b/test/accelerators/NNPA/backend/CMakeLists.txt index 304767d7ef..cb471514dc 100644 --- a/test/accelerators/NNPA/backend/CMakeLists.txt +++ b/test/accelerators/NNPA/backend/CMakeLists.txt @@ -362,12 +362,12 @@ set(NNPA_TEST_LIST # ==LIM== Input tensor must be less than or equal to 4 dimensions. # Model - # test_densenet121_cpu # accurary error - #test_inception_v1_cpu,zdnn_conv2d - #test_resnet50_cpu,zdnn_conv2d - #test_shufflenet_cpu,zdnn_matmul_op_ext - #test_squeezenet_cpu,zdnn_conv - #test_vgg19_cpu,zdnn_conv + test_densenet121_cpu,zdnn_conv2d + test_inception_v1_cpu,zdnn_conv2d + test_resnet50_cpu,zdnn_conv2d + test_shufflenet_cpu,zdnn_matmul_op_ext + # test_squeezenet_cpu,zdnn_conv # got NaN results + test_vgg19_cpu,zdnn_conv ) set(ENV_TEST_CASE_BY_USER "") foreach(test_name IN LISTS NNPA_TEST_LIST) @@ -394,6 +394,9 @@ add_custom_target(check-onnx-backend-nnpa COMMAND TEST_INSTRUCTION_CHECK=true ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-nnpa + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py @@ -405,6 +408,9 @@ add_custom_target(check-onnx-backend-dynamic-nnpa ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-dynamic-nnpa TEST_INSTRUCTION_CHECK=true TEST_DYNAMIC=true + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" ${NNPA_TESTS_ENVS_DYNAMIC} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py @@ -418,6 +424,9 @@ add_custom_target(check-onnx-backend-constant-nnpa # TEST_INSTRUCTION_CHECK=true ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-constant-nnpa TEST_CONSTANT=true + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py @@ -427,6 +436,9 @@ add_custom_target(check-onnx-backend-constant-nnpa add_custom_target(check-onnx-backend-compilerlib-nnpa COMMAND TEST_COMPILERLIB=true ONNX_HOME=${CMAKE_CURRENT_BINARY_DIR} + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. 
+ TEST_COMPILE_ARGS="--nnpa-saturation=true" ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py diff --git a/test/backend/common.py b/test/backend/common.py index b2a63d76dd..a9b077eae7 100644 --- a/test/backend/common.py +++ b/test/backend/common.py @@ -142,6 +142,13 @@ def compile_model(model, emit): command_list.append(model_name) command_list.append("-o=" + exec_base) + # Additional args passed in by TEST_COMPILE_ARGS + # Args are separated by ';' + additional_args = os.getenv("TEST_COMPILE_ARGS") + if additional_args is not None: + compile_args = additional_args.split(";") + command_list += compile_args + # Call frontend to process model_name.onnx, bit code will be generated. dynamic_inputs_dims = determine_dynamic_parameters(name) if args.verbose: diff --git a/test/compilerlib/CompilerLibTest.cpp b/test/compilerlib/CompilerLibTest.cpp index e57bd0bcf8..544cbcf018 100644 --- a/test/compilerlib/CompilerLibTest.cpp +++ b/test/compilerlib/CompilerLibTest.cpp @@ -22,12 +22,12 @@ bool compileFromFile = false; } #define PARSE_ARG(NAME, FLAG) \ if (arg.find(FLAG) == 0) { \ - NAME = arg.substr(sizeof(FLAG)); \ + (NAME) = arg.substr(sizeof(FLAG)); \ return true; \ } #define PARSE_FLAG(NAME, FLAG) \ if (arg.find(FLAG) == 0) { \ - NAME = true; \ + (NAME) = true; \ return true; \ } #define PARSE_UNSUPPORTED_FLAG(FLAG) \ @@ -106,4 +106,4 @@ int main(int argc, char *argv[]) { free(compiledFilename); free(errorMessage); return retVal; -} \ No newline at end of file +} diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir index 9b1bd2935d..81ef00353d 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir @@ -39,13 +39,13 @@ module { // CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [2, 3, 1, 0]} : (tensor<8x1x5x5xf32>) -> tensor<5x5x1x8xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_9_:%.+]] = "zhigh.Stick"([[VAR_8_]]) {layout = "HWCK"} : (tensor<5x5x1x8xf32>) -> tensor<5x5x1x8xf16, #zhigh.layout<{dataLayout = "HWCK"}>> -// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32>} : () -> tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_11_:%.+]] = "zhigh.Conv2D"([[VAR_7_]], [[VAR_9_]], [[VAR_10_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x28x28x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x1x8xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x28x28x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.MaxPool2D"([[VAR_11_]]) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [2, 2]} : (tensor<1x28x28x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x14x14x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [2, 3, 1, 0]} : 
(tensor<16x8x5x5xf32>) -> tensor<5x5x8x16xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_14_:%.+]] = "zhigh.Stick"([[VAR_13_]]) {layout = "HWCK"} : (tensor<5x5x8x16xf32>) -> tensor<5x5x8x16xf16, #zhigh.layout<{dataLayout = "HWCK"}>> -// CHECK-DAG: [[VAR_15_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_15_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32>} : () -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_16_:%.+]] = "zhigh.Conv2D"([[VAR_12_]], [[VAR_14_]], [[VAR_15_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x14x14x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x8x16xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x14x14x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK: [[VAR_17_:%.+]] = "zhigh.MaxPool2D"([[VAR_16_]]) {kernel_shape = [3, 3], padding_type = "VALID_PADDING", strides = [3, 3]} : (tensor<1x14x14x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x4x4x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK: [[VAR_18_:%.+]] = "zhigh.Unstick"([[VAR_17_]]) : (tensor<1x4x4x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x16x4x4xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir index 60ae647a83..d4af23b9eb 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir @@ -20,6 +20,6 @@ func.func @test_zlow_softmax_constant_shape() -> () { // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(5 : i64) : i64 // CHECK: %[[DIM2:.*]] = llvm.mlir.constant(10 : i64) : i64 - // CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0]], %[[DIM1]], %[[DIM2]]) : (i64, i64, !llvm.ptr, i64, i64, i64) -> () + // CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0]], %[[DIM1]], %[[DIM2]]) vararg(!llvm.func) : (i64, i64, !llvm.ptr, i64, i64, i64) -> () } diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir index fb357ea4a4..f0ea3355aa 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir @@ -62,7 +62,7 @@ func.func @test_stick() -> () { // CHECK: [[UNSTICKIFIED:%.+]] = llvm.extractvalue [[UNSTICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[UNSTICKIFIED_I8PTR:%.+]] = llvm.bitcast [[UNSTICKIFIED]] : !llvm.ptr to !llvm.ptr // CHECK: [[ZTENSOR_I8PTR:%.+]] = llvm.bitcast [[ZTENSOR]] : !llvm.ptr to !llvm.ptr - // CHECK: {{.*}} = llvm.call @zdnn_transform_ztensor([[ZTENSOR_I8PTR]], [[UNSTICKIFIED_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK: {{.*}} = llvm.call 
@zdnn_transform_ztensor([[ZTENSOR_I8PTR]], [[UNSTICKIFIED_I8PTR]]) vararg(!llvm.func) : (!llvm.ptr, !llvm.ptr) -> i32 // CHECK: llvm.return } diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir index 9f928ffc31..2307680415 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir @@ -61,7 +61,7 @@ func.func @test_stick() -> () { // CHECK: [[UNSTICKIFIED:%.+]] = llvm.extractvalue [[UNSTICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[UNSTICKIFIED_I8PTR:%.+]] = llvm.bitcast [[UNSTICKIFIED]] : !llvm.ptr to !llvm.ptr // CHECK: [[ZTENSOR_I8PTR:%.+]] = llvm.bitcast [[ZTENSOR]] : !llvm.ptr to !llvm.ptr - // CHECK: {{.*}} = llvm.call @zdnn_transform_ztensor([[ZTENSOR_I8PTR]], [[UNSTICKIFIED_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK: {{.*}} = llvm.call @zdnn_transform_ztensor([[ZTENSOR_I8PTR]], [[UNSTICKIFIED_I8PTR]]) vararg(!llvm.func) : (!llvm.ptr, !llvm.ptr) -> i32 // CHECK: llvm.return } diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir new file mode 100644 index 0000000000..b8cef7cf2b --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir @@ -0,0 +1,76 @@ +// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> { + %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16> + %1 = "zhigh.Unstick"(%0) : (tensor<*xf16>) -> tensor<*xf32> + return %1 : tensor<*xf32> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3 floordiv 64, d1, d2 floordiv 32, d2 mod 32, d3 mod 64)> +// CHECK-LABEL: func.func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5x7xf32>) -> memref<1x3x5x7xf32> { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 7){ +// CHECK: [[VAR_2_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x3x5x7xf32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_1_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#2, [[VAR_2_]]#3, [[VAR_2_]]#1] : memref<1x5x7x3xf32> +// CHECK: } +// CHECK: "zlow.stick"([[RES_1_]], [[RES_]]) {layout = "NHWC"} : (memref<1x5x7x3xf32>, memref<1x5x7x3xf16, #map>) -> () +// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf32> +// CHECK: "zlow.unstick"([[RES_]], [[RES_]]_1) {layout = "NHWC"} : (memref<1x5x7x3xf16, #map>, memref<1x5x7x3xf32>) -> () +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5x7xf32> +// CHECK-DAG: [[LOOP_1_:%.+]]:4 = 
krnl.define_loops 4 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to 5, [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 7, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 3){ +// CHECK: [[VAR_2_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2, [[VAR_2_1_]]#3] : memref<1x5x7x3xf32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_]], [[RES_3_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#3, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<1x3x5x7xf32> +// CHECK: } +// CHECK: return [[RES_3_]] : memref<1x3x5x7xf32> +// CHECK: } +} + +// ----- + +func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<1x?x?x7xf32>) -> tensor<*xf32> { + %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x?x?x7xf32>) -> tensor<*xf16> + %1 = "zhigh.Unstick"(%0) : (tensor<*xf16>) -> tensor<*xf32> + return %1 : tensor<*xf32> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3 floordiv 64, d1, d2 floordiv 32, d2 mod 32, d3 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-LABEL: func.func @should_lower_to_zlow_unknown_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x?x?x7xf32>) -> memref<1x?x?x7xf32> { +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32> +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_0_]], [[VAR_dim_]]) {{.*}}: memref<1x?x7x?xf16, #map> +// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32> +// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc([[VAR_dim_1_]], [[VAR_dim_2_]]) {{.*}}: memref<1x?x7x?xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 +// CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32> +// CHECK-DAG: [[VAR_dim_5_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_4_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_4_]], [[VAR_dim_5_]]), [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 7){ +// CHECK: [[VAR_2_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x?x?x7xf32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_1_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#2, [[VAR_2_]]#3, [[VAR_2_]]#1] : memref<1x?x7x?xf32> +// CHECK: } +// CHECK: "zlow.stick"([[RES_1_]], [[RES_]]) {layout = "NHWC"} : (memref<1x?x7x?xf32>, memref<1x?x7x?xf16, #map>) -> () +// CHECK: 
[[RES_2_:%.+]] = memref.alloc([[VAR_dim_0_]], [[VAR_dim_]]) {{.*}}: memref<1x?x7x?xf32> +// CHECK: "zlow.unstick"([[RES_]], [[RES_]]_6) {layout = "NHWC"} : (memref<1x?x7x?xf16, #map>, memref<1x?x7x?xf32>) -> () +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref<1x?x?x7xf32> +// CHECK-DAG: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_0_]]), [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 7, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_0_]], [[VAR_dim_]])){ +// CHECK: [[VAR_2_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2, [[VAR_2_1_]]#3] : memref<1x?x7x?xf32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_]], [[RES_3_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#3, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<1x?x?x7xf32> +// CHECK: } +// CHECK: return [[RES_3_]] : memref<1x?x?x7xf32> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir index 2d2983ba07..a34585cc01 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir @@ -189,7 +189,7 @@ func.func @conv_same_padding_no_bias_unknown_dims(%arg0: tensor<1x32x32x3xf16, # // CHECK: krnl.store [[VAR_c1_i64_]], [[RES_1_]]{{.}}[[VAR_c4_]]{{.}} : memref<7xi64> // CHECK: krnl.store [[VAR_c32_i64_]], [[RES_1_]]{{.}}[[VAR_c5_]]{{.}} : memref<7xi64> // CHECK: krnl.store [[VAR_c32_i64_]], [[RES_1_]]{{.}}[[VAR_c6_]]{{.}} : memref<7xi64> -// CHECK: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 1, 1, 1, 32, 64], value = dense_resource : tensor<4096xi8>} : () -> memref<1x1x1x1x32x64xf16> +// CHECK: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64]} : () -> memref<1x1x1x1x32x64xf16> // CHECK: "zlow.conv2d"([[PARAM_0_]], [[PARAM_1_]], [[VAR_2_]], [[RES_1_]], [[RES_]]) {act_func = "ACT_NONE", kernel_shape = [2, 2], padding_type = "SAME_PADDING", strides = [1, 1]} : (memref<1x32x32x3xf16, #map>, memref<2x2x3x1xf16, #map1>, memref<1x1x1x1x32x64xf16>, memref<7xi64>, memref<1x32x32x1xf16, #map>) -> () // CHECK: return [[RES_]] : memref<1x32x32x1xf16, #map> // CHECK: } diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir index cf23fb6de6..22a67eec40 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : 
(tensor<1x3x5x7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir index 7bf9766d88..06a7b68028 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir @@ -2,24 +2,17 @@ module { func.func @remove_stick_2d() -> tensor<2x3xf32> { - %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> + %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "2D", value = dense<[[0., 1., 2.], [3., 4., 5.]]> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> %1 = "zhigh.Unstick"(%0) : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } } -{-# - dialect_resources: { - builtin: { - zhigh: "0x0100000000003E0040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000410042004280000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - } - } -#-} // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> // CHECK-LABEL: func @remove_stick_2d // CHECK-SAME: () -> memref<2x3xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 1, 1, 1, 32, 64], value = dense_resource : tensor<4096xi8>} : () -> memref<2x3xf16, [[MAP_0_]]> +// CHECK-DAG: [[VAR_0_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, layout = "2D", name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64], value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> memref<2x3xf16, #map> // CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x3xf32> @@ -27,8 +20,3 @@ module { // CHECK: return [[RES_]] : memref<2x3xf32> // CHECK: } -// CHECK: dialect_resources: { -// CHECK-NEXT: builtin: { -// CHECK-NEXT: zhigh: 
"0x0100000000003E0040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000410042004280000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir index 1237ccd36d..508a819ccc 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow_1d(%arg0: tensor<7xf32>) -> tensor<*xf16> { %0 = "zhigh.Stick"(%arg0) {layout = "1D"} : (tensor<7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/driver/ccfd.mlir b/test/mlir/accelerators/nnpa/driver/ccfd.mlir index d8a2e57614..3c66f67da7 100644 --- a/test/mlir/accelerators/nnpa/driver/ccfd.mlir +++ b/test/mlir/accelerators/nnpa/driver/ccfd.mlir @@ -1,4 +1,4 @@ -// RUN: ccfd=$(dirname %s)/ccfd.onnx && curl -L https://github.com/IBM/ai-on-z-fraud-detection/raw/main/onnx%20models/ccf_lstm_static_tf2onnx_OS_new.onnx -o ${ccfd} && onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR -tag="test" ${ccfd} | FileCheck %s && rm -rf ${ccfd} +// RUN: ccfd=$(dirname %s)/ccfd.onnx && curl -L https://github.com/IBM/ai-on-z-fraud-detection/raw/main/onnx%20models/ccf_lstm_static_tf2onnx_OS_new.onnx -o ${ccfd} && onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" ${ccfd} | FileCheck %s && rm -rf ${ccfd} // COM: This test is to check regression on the IBM CCFD model. // COM: We expect that there are only one zlow.stick for the input and one zlow.unstick for the output. 
diff --git a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir index ac3e9e5074..dc26676eb4 100644 --- a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir +++ b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir index ab72847346..d5f40bbc1f 100644 --- a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir +++ b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir index 863efd1ee4..091663a32b 100644 --- a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir @@ -14,6 +14,6 @@ func.func @test_matmul_add_add(%arg0: tensor, %arg1: tensor<768x768 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<768x768xf32>) -> tensor { // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<49152xi8>} : () -> tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<5.000000e+00> : tensor<768xf32>} : () -> tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor> } diff --git a/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir index cd7d6e5f8c..8a6d0d0ede 100644 --- a/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR %s | FileCheck %s +// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR %s | FileCheck %s // Check whether the compiler can remove unstick/stick so that the output of zdnn softmax is passed directly to zdnn matmul. 
 func.func @softmax_matmul(%arg0: tensor) -> tensor {
diff --git a/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir b/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir
index 92ce0e6866..20438651da 100644
--- a/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir
+++ b/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir
@@ -1,5 +1,7 @@
 // RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --fold-std-alloc %s -split-input-file | FileCheck %s
 
+// -----
+
 func.func @should_fold() -> memref<3xi64> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir
index 609cab1aec..ab1c25860c 100644
--- a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir
+++ b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir
@@ -9,16 +9,11 @@ func.func @remove_stick_1d() -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>
   %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<6xf32>) -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>>
   return %res : tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>>
 
-  // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<6xf32>} : () -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
-  // CHECK-NEXT: }
-  // CHECK-NEXT: }
 }

 // -----

@@ -31,16 +26,10 @@ func.func @remove_stick_2d() -> tensor<2x3xf32> {
   %res = "zhigh.Unstick"(%st) {layout = "2D"} : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x3xf32>
   return %res : tensor<2x3xf32>
 
-  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "2D", value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
-  // CHECK-NEXT: }
-  // CHECK-NEXT: }
 }

 // -----

@@ -52,16 +41,10 @@ func.func @remove_stick_2ds() -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D
   %res = "zhigh.Stick"(%inp) {layout = "2DS"} : (tensor<2x3xf32>) -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>>
   return %res : tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>>
 
-  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "2DS", value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
-  // CHECK-NEXT: }
-  // CHECK-NEXT: }
 }

 // -----

@@ -73,16 +56,10 @@ func.func @remove_stick_3d() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3
   %res = "zhigh.Stick"(%inp) {layout = "3D"} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>>
   return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>>
 
-  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "3D", value = dense<{{.}}{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}> : tensor<1x2x3xf32>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
-  // CHECK-NEXT: }
-  // CHECK-NEXT: }
 }

 // -----

@@ -94,16 +71,10 @@ func.func @remove_stick_3ds() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "
   %res = "zhigh.Stick"(%inp) {layout = "3DS"} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>>
   return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>>
 
-  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "3DS", value = dense<{{.}}{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}> : tensor<1x2x3xf32>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
-  // CHECK-NEXT: }
-  // CHECK-NEXT: }
 }

 // -----

@@ -115,16 +86,10 @@ func.func @remove_stick_4d() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout =
   %res = "zhigh.Stick"(%inp) {layout = "4D"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>>
   return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>>
 
-  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>>
+  // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "4D", value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>>
   // CHECK-NOT: "onnx.Constant"
   // CHECK-NOT: "zhigh.Stick"
-
-  // CHECK: dialect_resources: {
-  // CHECK-NEXT: builtin: {
-  // CHECK-NEXT: zhigh: "0x0100000000003E0040..."
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -136,16 +101,10 @@ func.func @remove_stick_nhwc() -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout %res = "zhigh.Stick"(%inp) {layout = "NHWC"} : (tensor<1x1x2x3xf32>) -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> return %res : tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "NHWC", value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: 
"0x0100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003E0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000428000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -157,16 +116,10 @@ func.func @remove_stick_nchw() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout %res = "zhigh.Stick"(%inp) {layout = "NCHW"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "NCHW", value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: 
"0x0100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003E0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000428000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -178,16 +131,10 @@ func.func @remove_stick_cnnk_hwck() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLa %res = "zhigh.Stick"(%inp) {layout = "HWCK"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "HWCK", value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: 
"0x0100000000003E0040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000410042004280000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -297,16 +244,10 @@ func.func @out_of_range_minimum() -> tensor<1xf16, #zhigh.layout<{dataLayout = " %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<1xf32>) -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> return %res : tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> - // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> + // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<-3.402820e+38> : tensor<1xf32>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: 
"0x01000000FFFE000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -317,14 +258,9 @@ func.func @out_of_range_maximum() -> tensor<1xf16, #zhigh.layout<{dataLayout = " %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<1xf32>) -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> return %res : tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> - // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> + // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, layout = "1D", value = dense<3.402820e+38> : tensor<1xf32>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: 
"0x010000007FFE000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - // CHECK-NEXT: } - // CHECK-NEXT: } } diff --git a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir new file mode 100644 index 0000000000..e3761e8bd6 --- /dev/null +++ b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir @@ -0,0 +1,283 @@ +// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zlow-stick-expansion %s -split-input-file | FileCheck %s + +// ----- + + +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_stick_expansion_with_sat(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> + "zlow.stick"(%arg0, %alloc) {layout = "3DS", saturation = -1 : si64} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> () + return %alloc : memref<16x8x128xf16, #map> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<()[s0, s1] -> (s1 + 8)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<()[s0, s1] -> (s1 + 16)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<()[s0, s1] -> (s1 + 24)> +// CHECK-LABEL: func.func @test_stick_expansion_with_sat +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { +// CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index +// CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index +// CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<8.57315738E+9> : vector<4xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<-8.57315738E+9> : vector<4xf32> +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map> +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: [[VAR_reinterpret_cast_:%.+]] = 
memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ +// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_1_]]#2] +// CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[RES_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> +// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_1_]]#2, [[VAR_3_]]{{.}} +// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf32> +// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf16, #map> +// CHECK: affine.for [[I_3_:%.+]] = 0 to 64 step 32 { +// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[CST_4_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_7_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_5_]], [[CST_8_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_9_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_5_]], [[CST_12_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]1] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_5_]], [[CST_16_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_5_]], [[CST_20_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_5_]], [[CST_24_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_6_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_5_]], [[CST_28_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_7_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_22_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_1_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.minnumf 
[[LOAD_PARAM_0_MEM_2_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_3_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_4_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_5_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_6_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = arith.minnumf [[LOAD_PARAM_0_MEM_7_]], [[VAR_cst_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_21_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = arith.maxnumf [[VAR_22_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = arith.maxnumf [[VAR_23_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = arith.maxnumf [[VAR_24_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = arith.maxnumf [[VAR_25_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = arith.maxnumf [[VAR_26_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = arith.maxnumf [[VAR_27_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.maxnumf [[VAR_28_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_29_]], [[VAR_30_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK-DAG: [[VAR_38_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_31_]], [[VAR_32_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK-DAG: [[VAR_39_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_33_]], [[VAR_34_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK: [[VAR_40_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_35_]], [[VAR_36_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK: vector.store [[VAR_37_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_41_:%.+]] = affine.apply [[MAP_4_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_38_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]1] : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_42_:%.+]] = affine.apply [[MAP_5_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_39_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]2] : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_43_:%.+]] = affine.apply [[MAP_6_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_40_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]3] : memref<2x64xf16>, vector<8xf16> +// CHECK: } +// CHECK: } +// CHECK: return [[RES_]] : memref<16x8x128xf16, #map> +// CHECK: } +} + +// ----- + + +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_stick_expansion_without_sat(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> + "zlow.stick"(%arg0, %alloc) {layout = "3DS", saturation = 0 : si64} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> () + return %alloc : memref<16x8x128xf16, #map> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = 
affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<()[s0, s1] -> (s1 + 8)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<()[s0, s1] -> (s1 + 16)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<()[s0, s1] -> (s1 + 24)> +// CHECK-LABEL: func.func @test_stick_expansion_without_sat +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { +// CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index +// CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index +// CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map> +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ +// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_1_]]#2] +// CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[RES_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> +// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_1_]]#2, [[VAR_3_]]{{.}} +// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf32> +// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf16, #map> +// CHECK: affine.for [[I_3_:%.+]] = 0 to 64 step 32 { +// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[CST_4_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_7_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_5_]], [[CST_8_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_9_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_5_]], [[CST_12_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]1] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_5_]], [[CST_16_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = 
arith.addi [[VAR_5_]], [[CST_20_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_5_]], [[CST_24_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_6_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_5_]], [[CST_28_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_7_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_0_MEM_1_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK-DAG: [[VAR_22_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_2_]], [[LOAD_PARAM_0_MEM_3_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK-DAG: [[VAR_23_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_4_]], [[LOAD_PARAM_0_MEM_5_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK: [[VAR_24_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_6_]], [[LOAD_PARAM_0_MEM_7_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> +// CHECK: vector.store [[VAR_21_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_4_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_22_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_25_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_5_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_23_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_26_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_6_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: vector.store [[VAR_24_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_27_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: } +// CHECK: } +// CHECK: return [[RES_]] : memref<16x8x128xf16, #map> +// CHECK: } +} + +// ----- + + +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<16x8x128xf32> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf32> + "zlow.unstick"(%arg0, %alloc) {layout = "3DS"} : (memref<16x8x128xf16, #map>, memref<16x8x128xf32>) -> () + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0)[s0] -> (s0 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (d0 + 8)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 16)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0) -> (d0 + 24)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: [[MAP_7_:#.+]] = affine_map<()[s0] -> (-s0 + 121)> +// CHECK-DAG: [[MAP_8_:#.+]] = affine_map<()[s0] -> ((-s0) mod 8)> +// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<()[s0] -> (-s0 - (-s0) mod 8 + 128)> +// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)> +// CHECK-LABEL: 
func.func @test_unstick_expansion +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf16, #map>) -> memref<16x8x128xf32> { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index +// CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index +// CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[VAR_true_:%.+]] = arith.constant true +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ +// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2) +// CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> +// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#2){{.}}[[VAR_3_]]{{.}} +// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf16, #map> +// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf32> +// CHECK: scf.if [[VAR_true_]] { +// CHECK: scf.for [[I_3_:%.+]] = [[CST_0_]] to [[CST_64_]] step [[CST_32_]] { +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK-DAG: [[VAR_6_:%.+]] = affine.apply [[MAP_3_]]([[I_3_]]) +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_6_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]) +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_8_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]) +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_10_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[output1_2_:%.+]], [[VAR_output2_3_:%.+]] = 
"zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_12_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} +// CHECK: vector.store [[output1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_12_]], [[CST_4_]] : index +// CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_12_]], [[CST_8_]] : index +// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]4] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_12_]], [[CST_12_]] : index +// CHECK: vector.store [[VAR_output2_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_16_]] : index +// CHECK: vector.store [[output1_2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_12_]], [[CST_20_]] : index +// CHECK: vector.store [[VAR_output2_3_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_12_]], [[CST_24_]] : index +// CHECK: vector.store [[output1_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]8] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_12_]], [[CST_28_]] : index +// CHECK: vector.store [[VAR_output2_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: } +// CHECK: } else { +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = affine.apply [[MAP_7_]](){{.}}[[VAR_2_]]{{.}} +// CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[LOAD_VAR_reinterpret_cast_MEM_4_]] step [[CST_8_]] { +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_4_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[output1_0_]], [[VAR_output2_1_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_5_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_10_1_:%.+]] = affine.apply [[MAP_6_]]([[I_4_]]){{.}}[[VAR_2_]]{{.}} +// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = arith.addi [[VAR_10_1_]], [[CST_4_]] : index +// CHECK: vector.store [[VAR_output2_1_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]1] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[VAR_6_1_:%.+]] = affine.apply [[MAP_8_]](){{.}}[[VAR_2_]]{{.}} +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = affine.apply [[MAP_9_]](){{.}}[[VAR_2_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_6_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[output1_]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_6_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[RES_1_:%.+]] = memref.alloca() {{.*}}: 
memref<8xf32> +// CHECK: vector.store [[output1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_output2_1_]], [[RES_1_]]{{.}}[[CST_4_]]{{.}} : memref<8xf32>, vector<4xf32> +// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[VAR_6_1_]] step [[CST_1_]] { +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = krnl.load [[RES_1_]]{{.}}[[I_5_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_10_]]([[I_5_]]){{.}}[[VAR_2_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}} +// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_MEM_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return [[RES_]] : memref<16x8x128xf32> +// CHECK: } +} + diff --git a/test/mlir/conversion/krnl_to_llvm/input_verification.mlir b/test/mlir/conversion/krnl_to_llvm/input_verification.mlir index 5ea295be01..1764dfb320 100644 --- a/test/mlir/conversion/krnl_to_llvm/input_verification.mlir +++ b/test/mlir/conversion/krnl_to_llvm/input_verification.mlir @@ -8,7 +8,7 @@ module { func.func @main_graph(%arg0: memref<3x4x5xf32>, %arg1: memref) -> memref<3x4x5xf32> { return %arg0 : memref<3x4x5xf32> } - "krnl.entry_point"() {func = @main_graph, numInputs = 2 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [3 , 4 , 5] , \22name\22 : \22input0\22 }\0A , { \22type\22 : \22f32\22 , \22dims\22 : [-1 , 4 , 5] , \22name\22 : \22input1\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [3 , 4 , 5] , \22name\22 : \22output0\22 }\0A\0A]\00"} : () -> () + "krnl.entry_point"() {func = @main_graph, numInputs = 2 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [3 , 4 , 5] , \22name\22 : \22input0\22 }\0A , { \22type\22 : \22f32\22 , \22dims\22 : [-1 , 4 , 5] , \22name\22 : \22input1\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [3 , 4 , 5], \22name\22 : \22output0\22 }\0A\0A]\00"} : () -> () // CHECK: llvm.func @run_main_graph([[arg0_:.*]]: !llvm.ptr) -> !llvm.ptr { // CHECK-DAG: [[VAR_0:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> @@ -35,7 +35,7 @@ module { // CHECK: [[VAR_19_1_:%.+]] = llvm.icmp "ne" [[VAR_17_1_]], [[VAR_18_1_]] : i64 // CHECK: llvm.cond_br [[VAR_19_1_]], ^bb1, ^bb2 // CHECK: ^bb1: // pred: ^bb0 -// CHECK: llvm.call @printf([[VAR_16_1_]], [[VAR_18_1_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_16_1_]], [[VAR_18_1_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_22_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_22_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -46,7 +46,7 @@ module { // CHECK: [[VAR_27_:%.+]] = llvm.icmp "ne" [[VAR_14_1_]], [[VAR_26_]] : i64 // CHECK: llvm.cond_br [[VAR_27_]], ^bb3, ^bb4 // CHECK: ^bb3: // pred: ^bb2 -// CHECK: llvm.call @printf([[VAR_13_1_]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @printf([[VAR_13_1_]]) vararg(!llvm.func) : (!llvm.ptr) -> () // CHECK: [[VAR_29_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_29_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -55,7 +55,7 @@ module { // CHECK: [[VAR_31_1_:%.+]] = llvm.icmp "ne" [[VAR_12_1_]], [[VAR_31_]] : i64 // CHECK: llvm.cond_br [[VAR_31_1_]], ^bb5, ^bb6 // CHECK: ^bb5: // pred: ^bb4 -// CHECK: llvm.call @printf([[VAR_11_1_]], [[VAR_31_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call 
@printf([[VAR_11_1_]], [[VAR_31_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_31_2_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_31_2_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -65,7 +65,7 @@ module { // CHECK: [[VAR_32_2_:%.+]] = llvm.icmp "ne" [[VAR_12_1_]], [[LOAD_VAR_32_MEM_]] : i64 // CHECK: llvm.cond_br [[VAR_32_2_]], ^bb7, ^bb8 // CHECK: ^bb7: // pred: ^bb6 -// CHECK: llvm.call @printf([[LOAD_arg2_MEM_1_]], [[LOAD_VAR_32_MEM_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[LOAD_arg2_MEM_1_]], [[LOAD_VAR_32_MEM_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_32_3_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_32_3_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -75,7 +75,7 @@ module { // CHECK: [[VAR_40_:%.+]] = llvm.icmp "ne" [[VAR_9_2_]], [[LOAD_VAR_36_MEM_]] : i64 // CHECK: llvm.cond_br [[VAR_40_]], ^bb9, ^bb10 // CHECK: ^bb9: // pred: ^bb8 -// CHECK: llvm.call @printf([[VAR_8_2_]], [[LOAD_VAR_36_MEM_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_8_2_]], [[LOAD_VAR_36_MEM_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_41_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_41_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -85,7 +85,7 @@ module { // CHECK: [[VAR_42_0_:%.+]] = llvm.icmp "ne" [[VAR_7_2_]], [[LOAD_VAR_40_MEM_]] : i64 // CHECK: llvm.cond_br [[VAR_42_0_]], ^bb11, ^bb12 // CHECK: ^bb11: // pred: ^bb10 -// CHECK: llvm.call @printf([[VAR_6_2_]], [[LOAD_VAR_40_MEM_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_6_2_]], [[LOAD_VAR_40_MEM_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_42_1_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_42_1_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -96,7 +96,7 @@ module { // CHECK: [[VAR_45_:%.+]] = llvm.icmp "ne" [[VAR_14_1_]], [[VAR_44_]] : i64 // CHECK: llvm.cond_br [[VAR_45_]], ^bb13, ^bb14 // CHECK: ^bb13: // pred: ^bb12 -// CHECK: llvm.call @printf([[VAR_5_2_]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @printf([[VAR_5_2_]]) vararg(!llvm.func) : (!llvm.ptr) -> () // CHECK: [[VAR_47_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_47_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -105,7 +105,7 @@ module { // CHECK: [[VAR_49_:%.+]] = llvm.icmp "ne" [[VAR_12_1_]], [[VAR_48_]] : i64 // CHECK: llvm.cond_br [[VAR_49_]], ^bb15, ^bb16 // CHECK: ^bb15: // pred: ^bb14 -// CHECK: llvm.call @printf([[VAR_4_2_]], [[VAR_48_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_4_2_]], [[VAR_48_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_50_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_50_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -115,7 +115,7 @@ module { // CHECK: [[VAR_53_:%.+]] = llvm.icmp "slt" [[LOAD_VAR_51_MEM_]], [[VAR_3_2_]] : i64 // CHECK: llvm.cond_br [[VAR_53_]], ^bb17, ^bb18 // CHECK: ^bb17: // pred: ^bb16 -// CHECK: llvm.call @printf([[VAR_2_2_]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @printf([[VAR_2_2_]]) vararg(!llvm.func) : (!llvm.ptr) -> () // CHECK: [[VAR_54_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store 
[[VAR_15_1_]], [[VAR_54_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -125,7 +125,7 @@ module { // CHECK: [[VAR_55_1_:%.+]] = llvm.icmp "ne" [[VAR_9_2_]], [[LOAD_VAR_55_MEM_]] : i64 // CHECK: llvm.cond_br [[VAR_55_1_]], ^bb19, ^bb20 // CHECK: ^bb19: // pred: ^bb18 -// CHECK: llvm.call @printf([[VAR_1_2_]], [[LOAD_VAR_55_MEM_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_1_2_]], [[LOAD_VAR_55_MEM_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_55_2_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_55_2_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr @@ -135,7 +135,7 @@ module { // CHECK: [[VAR_55_4_:%.+]] = llvm.icmp "ne" [[VAR_7_2_]], [[LOAD_VAR_54_MEM_]] : i64 // CHECK: llvm.cond_br [[VAR_55_4_]], ^bb21, ^bb22 // CHECK: ^bb21: // pred: ^bb20 -// CHECK: llvm.call @printf([[VAR_0_2_]], [[LOAD_VAR_54_MEM_]]) : (!llvm.ptr, i64) -> () +// CHECK: llvm.call @printf([[VAR_0_2_]], [[LOAD_VAR_54_MEM_]]) vararg(!llvm.func) : (!llvm.ptr, i64) -> () // CHECK: [[VAR_56_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr // CHECK: llvm.store [[VAR_15_1_]], [[VAR_56_]] : i32, !llvm.ptr // CHECK: llvm.return [[VAR_14_2_]] : !llvm.ptr diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir index 43987a229d..d739869098 100644 --- a/test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir +++ b/test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir @@ -16,7 +16,9 @@ func.func @test_krnl_erf_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32> // CHECK-LABEL: test_krnl_erf_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ERF_RES:%.+]] = llvm.call @erff([[SCALAR_IN]]) : (f32) -> f32 @@ -39,7 +41,9 @@ func.func @test_krnl_acos_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32 // CHECK-LABEL: test_krnl_acos_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x 
i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @acosf([[SCALAR_IN]]) : (f32) -> f32 @@ -62,7 +66,9 @@ func.func @test_krnl_acosh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3 // CHECK-LABEL: test_krnl_acosh_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @acoshf([[SCALAR_IN]]) : (f32) -> f32 @@ -85,7 +91,9 @@ func.func @test_krnl_asin_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32 // CHECK-LABEL: test_krnl_asin_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @asinf([[SCALAR_IN]]) : (f32) -> f32 @@ -108,7 +116,9 @@ func.func @test_krnl_asinh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3 // CHECK-LABEL: test_krnl_asinh_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = 
llvm.call @asinhf([[SCALAR_IN]]) : (f32) -> f32 @@ -131,7 +141,9 @@ func.func @test_krnl_atan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32 // CHECK-LABEL: test_krnl_atan_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @atanf([[SCALAR_IN]]) : (f32) -> f32 @@ -153,7 +165,9 @@ func.func @test_krnl_atanh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3 // CHECK-LABEL: test_krnl_atanh_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @atanhf([[SCALAR_IN]]) : (f32) -> f32 @@ -175,7 +189,9 @@ func.func @test_krnl_tan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32> // CHECK-LABEL: test_krnl_tan_lowering // CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32> +// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr // CHECK: [[ACOS_RES:%.+]] = llvm.call @tanf([[SCALAR_IN]]) : (f32) -> f32 diff --git a/test/mlir/conversion/krnl_to_llvm/reshape.mlir b/test/mlir/conversion/krnl_to_llvm/reshape.mlir index 2a530a347c..97d5374ec5 100644 --- 
a/test/mlir/conversion/krnl_to_llvm/reshape.mlir +++ b/test/mlir/conversion/krnl_to_llvm/reshape.mlir @@ -18,10 +18,8 @@ func.func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tens // COM: Check that there is no copy but only a new MemRef with a new view, i.e. new sizes and strides. // CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> -// CHECK-DAG: [[EXTRACT_1:%.+]] = llvm.extractvalue [[INSERT_7_]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-DAG: [[EXTRACT_2:%.+]] = llvm.extractvalue [[INSERT_7_]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: [[INSERT_8_:%.+]] = llvm.insertvalue [[EXTRACT_1]], [[NEW_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> -// CHECK-DAG: [[INSERT_9_:%.+]] = llvm.insertvalue [[EXTRACT_2]], [[INSERT_8_]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> +// CHECK: [[INSERT_8_:%.+]] = llvm.insertvalue {{.*}}, [[NEW_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> +// CHECK-DAG: [[INSERT_9_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_8_]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64 // CHECK: [[INSERT_10_:%.+]] = llvm.insertvalue [[C0]], [[INSERT_9_]][2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK: [[INSERT_11_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_10_]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> diff --git a/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir b/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir index 82b1dda65b..49ed94b610 100644 --- a/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir +++ b/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir @@ -41,32 +41,27 @@ func.func private @test_loop_simple_main_graph(%arg0: tensor, %arg1: tensor // CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xi64> -// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index -// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 1){ -// CHECK-DAG: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[CST_1_5_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64> -// CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref -// CHECK: [[VAR_17_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64 -// CHECK: krnl.store [[VAR_17_]], [[RES_3_]]{{.}}[[VAR_14_]]{{.}} : memref<1xi64> -// CHECK: } -// CHECK-DAG: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64> -// CHECK-DAG: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref to memref +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_2_]]{{.}} : memref<1xi64> +// CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : tensor<1xi64> to 
memref<1xi64> -// CHECK-DAG: [[LOAD_VAR_10_MEM_:%.+]] = krnl.load [[VAR_10_]][] : memref -// CHECK: krnl.store [[LOAD_VAR_10_MEM_]], [[RES_1_]][] : memref -// CHECK-DAG: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64 +// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index +// CHECK: krnl.store [[VAR_10_]], [[RES_3_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64> +// CHECK-DAG: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64> +// CHECK-DAG: [[VAR_12_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref to memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[VAR_11_]] : tensor<1xi64> to memref<1xi64> +// CHECK-DAG: [[LOAD_VAR_12_MEM_:%.+]] = krnl.load [[VAR_12_]][] : memref +// CHECK: krnl.store [[LOAD_VAR_12_MEM_]], [[RES_1_]][] : memref +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_1_6_:%.+]] = arith.constant 1 : index -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 0 to 1){ -// CHECK: [[VAR_14_1_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[VAR_11_]]{{.}}[[VAR_14_1_]]{{.}} : memref<1xi64> -// CHECK: krnl.store [[LOAD_RES_MEM_1_]], [[RES_]]{{.}}[[VAR_14_1_]]{{.}} : memref<1xi64> +// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 1){ +// CHECK: [[VAR_16_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_13_MEM_:%.+]] = krnl.load [[VAR_13_]]{{.}}[[VAR_16_]]{{.}} : memref<1xi64> +// CHECK: krnl.store [[LOAD_VAR_13_MEM_]], [[RES_]]{{.}}[[VAR_16_]]{{.}} : memref<1xi64> // CHECK: } // CHECK: }) : () -> () // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir index 2d35ed394f..5e149d2d96 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir @@ -183,6 +183,7 @@ func.func private @test_elementwise_op_with_array_and_scalar_values_2(%arg0 : te // ----- // SIMD for the lowest dim; possible broadcast for the top 2 dims. 
+ func.func @roberta_partial_simd_1dim_v1(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor @@ -204,24 +205,30 @@ func.func @roberta_partial_simd_1dim_v1(%arg0: tensor, %arg1: tenso // CHECK-DAG: [[VAR_1_:%.+]] = affine.max [[MAP_0_]](){{.}}[[VAR_dim_0_]], [[VAR_dim_2_]]{{.}} // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]], [[VAR_1_]]) {{.*}}: memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_0_]], [[VAR_1_]]{{.}}, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK-DAG: [[VAR_3_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_0_]], [[VAR_1_]]{{.}}){ +// CHECK-DAG: [[VAR_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_5_:%.+]] = arith.select [[VAR_4_]], [[VAR_3_]]#0, [[CST_0_]] : index // CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index -// CHECK: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_5_]], [[VAR_7_]], [[VAR_3_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_9_]], [[VAR_3_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_10_]], [[VAR_12_]], [[VAR_3_]]#2] : memref, vector<32xf32> -// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_14_]], [[RES_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_3_]]#2] : memref, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_3_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_3_]]#0, [[CST_0_]] : index +// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[VAR_3_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> 
[[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_13_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_5_]], [[VAR_7_]], [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_9_]], [[VAR_11_]], [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_16_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_16_]], [[RES_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -230,6 +237,7 @@ func.func @roberta_partial_simd_1dim_v1(%arg0: tensor, %arg1: tenso // ----- // Same as above, but now the second arg is "constant" in the top 2 dims; and thus no need for broadcast select for the top 2 either. + func.func @roberta_partial_simd_1dim_v2(%arg0: tensor, %arg1: tensor<768xf32>) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor<768xf32>) -> tensor return %0 : tensor @@ -245,14 +253,18 @@ func.func @roberta_partial_simd_1dim_v2(%arg0: tensor, %arg1: tenso // CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0], [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_1_]]#2] : memref<768xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_3_]]{{.}} : memref<768xf32>, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] 
= arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -261,6 +273,7 @@ func.func @roberta_partial_simd_1dim_v2(%arg0: tensor, %arg1: tenso // ----- // same as above, 2nd param has a useless 1 added in front. + func.func @roberta_partial_simd_1dim_v3(%arg0: tensor, %arg1: tensor<1x768xf32>) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor<1x768xf32>) -> tensor return %0 : tensor @@ -276,14 +289,18 @@ func.func @roberta_partial_simd_1dim_v3(%arg0: tensor, %arg1: tenso // CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0], [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[CST_0_]], [[VAR_1_]]#2] : memref<1x768xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[CST_0_]], [[VAR_3_]]{{.}} : memref<1x768xf32>, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -292,6 +309,7 @@ func.func @roberta_partial_simd_1dim_v3(%arg0: tensor, %arg1: tenso // ----- // same as above, 2nd param has 2 useless 1 added in front. 
+ func.func @roberta_partial_simd_1dim_v4(%arg0: tensor, %arg1: tensor<1x1x768xf32>) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor<1x1x768xf32>) -> tensor return %0 : tensor @@ -307,14 +325,18 @@ func.func @roberta_partial_simd_1dim_v4(%arg0: tensor, %arg1: tenso // CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0], [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[CST_0_]], [[CST_0_]], [[VAR_1_]]#2] : memref<1x1x768xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref, vector<32xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[CST_0_]], [[CST_0_]], [[VAR_3_]]{{.}} : memref<1x1x768xf32>, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -503,6 +525,7 @@ func.func @roberta_partial_simd_1dim_scalar(%arg0: tensor, %arg1: t // ----- // has ?x? in the first 2 dims for both params; collapse of the lowest 2 dims. 
+ func.func @roberta_partial_simd_2dim_v1(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor @@ -544,24 +567,30 @@ func.func @roberta_partial_simd_2dim_v1(%arg0: tensor, %arg1: tens // CHECK: affine.store [[VAR_1_]], [[RES_3_]][1] : memref<3xindex> // CHECK: affine.store [[CST_768_]], [[RES_3_]][2] : memref<3xindex> // CHECK-DAG: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_]]([[RES_]]_10) : (memref, memref<3xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_0_]], [[VAR_1_]]{{.}}, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK-DAG: [[VAR_3_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_0_]], [[VAR_1_]]{{.}}){ +// CHECK-DAG: [[VAR_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_5_:%.+]] = arith.select [[VAR_4_]], [[VAR_3_]]#0, [[CST_0_]] : index // CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index -// CHECK: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_5_]], [[VAR_7_]], [[VAR_3_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_9_]], [[VAR_3_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK: [[LOAD_VAR_reshape_9_MEM_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_10_]], [[VAR_12_]], [[VAR_3_]]#2] : memref, vector<32xf32> -// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_9_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_14_]], [[VAR_reshape_11_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_3_]]#2] : memref, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_3_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_3_]]#0, [[CST_0_]] : index +// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[VAR_3_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, 
!krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_13_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_5_]], [[VAR_7_]], [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_9_MEM_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_9_]], [[VAR_11_]], [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_16_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_9_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_16_]], [[VAR_reshape_11_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_13_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -570,6 +599,7 @@ func.func @roberta_partial_simd_2dim_v1(%arg0: tensor, %arg1: tens // ----- // has ?x2 and ?x? in the first 2 dims + func.func @roberta_partial_simd_2dim_v2(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor @@ -606,21 +636,27 @@ func.func @roberta_partial_simd_2dim_v2(%arg0: tensor, %arg1: tens // CHECK: affine.store [[CST_2_]], [[RES_3_]][1] : memref<3xindex> // CHECK: affine.store [[CST_768_]], [[RES_3_]][2] : memref<3xindex> // CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[RES_]]([[RES_]]_8) : (memref, memref<3xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK-DAG: [[VAR_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[VAR_3_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index -// CHECK: [[VAR_4_:%.+]] = arith.select [[VAR_3_]], [[VAR_2_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]], [[VAR_2_]]#1, [[VAR_2_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_2_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_2_]]#1, [[CST_0_]] : index -// CHECK: [[LOAD_VAR_reshape_7_MEM_:%.+]] = vector.load [[VAR_reshape_7_]]{{.}}[[VAR_7_]], [[VAR_9_]], [[VAR_2_]]#2] : memref, vector<32xf32> -// CHECK: [[VAR_11_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_7_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_11_]], [[VAR_reshape_9_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2] : memref, vector<32xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = arith.select [[VAR_3_]], [[VAR_2_]]#0, [[CST_0_]] : index +// CHECK-DAG: 
[[VAR_5_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.select [[VAR_5_]], [[VAR_2_]]#0, [[CST_0_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_2_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_10_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]], [[VAR_2_]]#1, [[VAR_10_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_7_MEM_:%.+]] = vector.load [[VAR_reshape_7_]]{{.}}[[VAR_6_]], [[VAR_8_]], [[VAR_10_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_13_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_7_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_13_]], [[VAR_reshape_9_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_10_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -629,6 +665,7 @@ func.func @roberta_partial_simd_2dim_v2(%arg0: tensor, %arg1: tens // ----- // has ?x2 and ?x1 in the first 2 dims + func.func @roberta_partial_simd_2dim_v3(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor @@ -663,18 +700,24 @@ func.func @roberta_partial_simd_2dim_v3(%arg0: tensor, %arg1: tens // CHECK: affine.store [[CST_2_]], [[RES_3_]][1] : memref<3xindex> // CHECK: affine.store [[CST_768_]], [[RES_3_]][2] : memref<3xindex> // CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref, memref<3xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 768){ -// CHECK-DAG: [[VAR_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[VAR_3_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index -// CHECK: [[VAR_4_:%.+]] = arith.select [[VAR_3_]], [[VAR_2_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]], [[VAR_2_]]#1, [[VAR_2_]]#2] : memref, vector<32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index -// CHECK: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_2_]]#0, [[CST_0_]] : index -// CHECK: [[LOAD_VAR_reshape_5_MEM_:%.+]] = vector.load [[VAR_reshape_5_]]{{.}}[[VAR_7_]], [[CST_0_]], [[VAR_2_]]#2] 
: memref, vector<32xf32>
-// CHECK: [[VAR_9_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_5_MEM_]] : vector<32xf32>
-// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_7_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2] : memref, vector<32xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.select [[VAR_3_]], [[VAR_2_]]#0, [[CST_0_]] : index
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.select [[VAR_5_]], [[VAR_2_]]#0, [[CST_0_]] : index
+// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){
+// CHECK: [[VAR_8_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]], [[VAR_2_]]#1, [[VAR_8_]]{{.}} : memref, vector<32xf32>
+// CHECK-DAG: [[LOAD_VAR_reshape_5_MEM_:%.+]] = vector.load [[VAR_reshape_5_]]{{.}}[[VAR_6_]], [[CST_0_]], [[VAR_8_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_11_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_5_MEM_]] : vector<32xf32>
+// CHECK: vector.store [[VAR_11_]], [[VAR_reshape_7_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_8_]]{{.}} : memref, vector<32xf32>
+// CHECK: }
 // CHECK: }
 // CHECK: return [[RES_]] : memref
 // CHECK: }
@@ -682,17 +725,20 @@ func.func @roberta_partial_simd_2dim_v3(%arg0: tensor, %arg1: tens
 // -----
-// Currently does partial simd only when partial simd static size is a mutiple of VL.
+// Static size is not a multiple of archVL; will do some SIMD and some scalar code.
+ func.func @roberta_partial_simd_2dim_not_0_mod_vl(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0, s1] -> (s1, s0)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1] -> (s1)> // CHECK-LABEL: func.func @roberta_partial_simd_2dim_not_0_mod_vl // CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref) -> memref { +// CHECK-DAG: [[CST_665_:%.+]] = arith.constant 665 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_2660_:%.+]] = arith.constant 2660 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs @@ -703,33 +749,71 @@ func.func @roberta_partial_simd_2dim_not_0_mod_vl(%arg0: tensor, % // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_0_:%.+]] = affine.max [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1] // CHECK-DAG: [[VAR_1_:%.+]] = affine.max [[MAP_0_]](){{.}}[[VAR_dim_0_]], [[VAR_dim_2_]]{{.}} -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]], [[VAR_1_]]) {{.*}}: memref -// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_dim_]]_0, [[VAR_dim_]]_2, [[VAR_0_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_dim_]]_0, [[VAR_dim_]]_2, [[VAR_0_]], [[VAR_1_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 95, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 7){ -// CHECK-DAG: [[VAR_3_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = arith.select [[VAR_4_]], [[VAR_3_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index -// CHECK: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]], [[VAR_7_]], [[VAR_3_]]#2, [[VAR_3_]]#3] : memref -// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_9_]], [[VAR_3_]]#0, [[CST_0_]] : index -// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[VAR_3_]]#1, [[CST_0_]] : index -// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_10_]], [[VAR_12_]], [[VAR_3_]]#2, [[VAR_3_]]#3] : memref -// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_14_]], [[RES_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_3_]]#2, [[VAR_3_]]#3] : memref +// CHECK: [[VAR_2_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index +// CHECK: [[VAR_3_:%.+]] = arith.muli [[VAR_2_]], [[CST_2660_]] : index +// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_3_]], [[CST_128_]] : index +// 
CHECK: [[RES_:%.+]] = memref.alloc([[VAR_4_]]) {{.*}}: memref +// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}{{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref to memref +// CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[VAR_dim_3_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[VAR_dim_4_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_665_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref, memref<3xindex>) -> memref +// CHECK-DAG: [[VAR_dim_6_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_7_:%.+]] = memref.dim [[PARAM_1_]], [[CST_1_]] : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[VAR_dim_6_]], [[RES_2_]][0] : memref<3xindex> +// CHECK: affine.store [[VAR_dim_7_]], [[RES_2_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_665_]], [[RES_2_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_2_]]) : (memref, memref<3xindex>) -> memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[VAR_0_]], [[RES_3_]][0] : memref<3xindex> +// CHECK: affine.store [[VAR_1_]], [[RES_3_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_665_]], [[RES_3_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_11_:%.+]] = memref.reshape [[VAR_view_]]([[RES_3_]]) : (memref, memref<3xindex>) -> memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_0_]], [[VAR_1_]]{{.}}){ +// CHECK-DAG: [[VAR_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_6_]]#0, [[CST_0_]] : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi sgt, [[VAR_dim_0_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_9_]], [[VAR_6_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[VAR_6_]]#0, [[CST_0_]] : index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpi sgt, [[VAR_dim_2_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_13_]], [[VAR_6_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 634){ +// CHECK: [[VAR_17_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]], [[VAR_10_]], [[VAR_17_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_9_MEM_:%.+]] = 
vector.load [[VAR_reshape_9_]]{{.}}[[VAR_12_]], [[VAR_14_]], [[VAR_17_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_20_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_9_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_20_]], [[VAR_reshape_11_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[VAR_17_]]{{.}} : memref, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = 640 to 665){ +// CHECK: [[VAR_17_1_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_8_]], [[VAR_10_]], [[VAR_17_1_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_9_MEM_1_:%.+]] = krnl.load [[VAR_reshape_9_]]{{.}}[[VAR_12_]], [[VAR_14_]], [[VAR_17_1_]]{{.}} : memref +// CHECK: [[VAR_20_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_1_]], [[LOAD_VAR_reshape_9_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_20_1_]], [[VAR_reshape_11_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[VAR_17_1_]]{{.}} : memref +// CHECK: } // CHECK: } -// CHECK: return [[RES_]] : memref +// CHECK: return [[VAR_view_]] : memref // CHECK: } } // ----- // Tests found in roberta when there are leading ones in tensors; they are ignored for the broadcasting as they contribute nothing. + func.func @add_from_test_with_1_1(%arg0 : tensor<1x1x128xf32>, %arg1 : tensor<128xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1x1x128xf32>, tensor<128xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -737,16 +821,24 @@ func.func @add_from_test_with_1_1(%arg0 : tensor<1x1x128xf32>, %arg1 : tensor<12 // mlir2FileCheck.py // CHECK-LABEL: func.func @add_from_test_with_1_1 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x1x128xf32>, [[PARAM_1_:%.+]]: memref<128xf32>) -> memref<1x1x128xf32> { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x1x128xf32> -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 128){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[CST_0_]], [[CST_0_]], [[VAR_1_]]#2] : memref<1x1x128xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_1_]]#2] : memref<128xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref<1x1x128xf32>, vector<32xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_1_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<1x1x128xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_2_]][0] : memref<1xindex> +// CHECK: 
[[VAR_reshape_2_:%.+]] = memref.reshape [[RES_]]([[RES_]]_1) : (memref<1x1x128xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 128){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_4_]], [[VAR_reshape_2_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref<1x1x128xf32> // CHECK: } @@ -755,6 +847,7 @@ func.func @add_from_test_with_1_1(%arg0 : tensor<1x1x128xf32>, %arg1 : tensor<12 // ----- // Same, with the 2 params swapped, to make sure order does not matter. + func.func @add_from_test_with_1_1_swapped(%arg0 : tensor<128xf32>, %arg1 : tensor<1x1x128xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) : (tensor<128xf32>, tensor<1x1x128xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -762,16 +855,24 @@ func.func @add_from_test_with_1_1_swapped(%arg0 : tensor<128xf32>, %arg1 : tenso // mlir2FileCheck.py // CHECK-LABEL: func.func @add_from_test_with_1_1_swapped // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<128xf32>, [[PARAM_1_:%.+]]: memref<1x1x128xf32>) -> memref<1x1x128xf32> { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x1x128xf32> -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 128){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#2] : memref<128xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = vector.load [[PARAM_1_]]{{.}}[[CST_0_]], [[CST_0_]], [[VAR_1_]]#2] : memref<1x1x128xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref<1x1x128xf32>, vector<32xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_1_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_1_]]) : (memref<1x1x128xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_2_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_2_:%.+]] = memref.reshape [[RES_]]([[RES_]]_1) : (memref<1x1x128xf32>, 
memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 128){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_4_]], [[VAR_reshape_2_]]{{.}}[[VAR_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref<1x1x128xf32> // CHECK: } @@ -1172,6 +1273,7 @@ func.func @test_prelu_broadcast2(%arg0: tensor<3x4x5xf32>, %arg1: tensor<1x5xf32 // ----- // case where only the lowest 2 dims are splatted. + func.func @add_partial_splat(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3x1x1xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<3x1x1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -1199,15 +1301,19 @@ func.func @add_partial_splat(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3x1x1xf32 // CHECK: affine.store [[CST_3_]], [[RES_3_]][1] : memref<3xindex> // CHECK: affine.store [[CST_20_]], [[RES_3_]][2] : memref<3xindex> // CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<2x3x4x5xf32>, memref<3xindex>) -> memref<2x3x20xf32> -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#2 20 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 20){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref<2x3x20xf32>, vector<20xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_2_MEM_:%.+]] = krnl.load [[VAR_reshape_2_]]{{.}}[[VAR_1_]]#1, [[CST_0_]]{{.}} : memref<3x1xf32> -// CHECK: [[VAR_4_:%.+]] = vector.splat [[LOAD_VAR_reshape_2_MEM_]] : vector<20xf32> -// CHECK: [[VAR_5_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[VAR_4_]] : vector<20xf32> -// CHECK: vector.store [[VAR_5_]], [[VAR_reshape_4_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref<2x3x20xf32>, vector<20xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 20 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 20){ +// CHECK: [[VAR_3_:%.+]] = 
krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref<2x3x20xf32>, vector<20xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_2_MEM_:%.+]] = krnl.load [[VAR_reshape_2_]]{{.}}[[VAR_1_]]#1, [[CST_0_]]{{.}} : memref<3x1xf32> +// CHECK: [[VAR_6_:%.+]] = vector.splat [[LOAD_VAR_reshape_2_MEM_]] : vector<20xf32> +// CHECK: [[VAR_7_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[VAR_6_]] : vector<20xf32> +// CHECK: vector.store [[VAR_7_]], [[VAR_reshape_4_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_3_]]{{.}} : memref<2x3x20xf32>, vector<20xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref<2x3x4x5xf32> // CHECK: } @@ -1216,6 +1322,7 @@ func.func @add_partial_splat(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3x1x1xf32 // ----- // Test onnx.Erf lowering from onnx to kerneL + func.func @test_erf(%arg0: tensor) -> (tensor<*xf32>) { %0 = "onnx.Erf"(%arg0): (tensor) -> (tensor<*xf32>) return %0 : tensor<*xf32> @@ -1223,7 +1330,7 @@ func.func @test_erf(%arg0: tensor) -> (tensor<*xf32>) { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func @test_erf // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -1240,14 +1347,16 @@ func.func @test_erf(%arg0: tensor) -> (tensor<*xf32>) { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.erf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.erf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : 
memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1305,6 +1414,7 @@ func.func private @test_tanh(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_sinh(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Sinh"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1312,7 +1422,7 @@ func.func private @test_sinh(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_sinh // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.000000e+00> : vector<32xf32> @@ -1331,18 +1441,20 @@ func.func private @test_sinh(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32> -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : vector<32xf32> -// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32> +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : vector<32xf32> +// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: vector.store [[VAR_10_]], 
[[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1350,6 +1462,7 @@ func.func private @test_sinh(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_cosh(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Cosh"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1357,7 +1470,7 @@ func.func private @test_cosh(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_cosh // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.000000e+00> : vector<32xf32> @@ -1376,18 +1489,20 @@ func.func private @test_cosh(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32> -// CHECK: [[VAR_9_:%.+]] = arith.addf [[VAR_7_]], [[VAR_8_]] : vector<32xf32> -// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32> +// CHECK: [[VAR_9_:%.+]] = arith.addf [[VAR_7_]], [[VAR_8_]] : vector<32xf32> +// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: vector.store 
[[VAR_10_]], [[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: }
// CHECK: }
// CHECK: return [[VAR_view_]] : memref
// CHECK: }
@@ -1471,6 +1586,7 @@ func.func private @test_log(%arg0 : tensor) -> tensor<*xf32> {
// -----
+
func.func private @test_sigmoid(%arg0 : tensor) -> tensor<*xf32> {
%0 = "onnx.Sigmoid"(%arg0) : (tensor) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
@@ -1478,7 +1594,7 @@ func.func private @test_sigmoid(%arg0 : tensor) -> tensor<*xf32> {
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
-// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
// CHECK-LABEL: func.func private @test_sigmoid
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32>
@@ -1497,17 +1613,19 @@ func.func private @test_sigmoid(%arg0 : tensor) -> tensor<*xf32> {
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex>
-// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
-// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
-// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1]){
-// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
-// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
-// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
-// CHECK: [[VAR_7_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32>
-// CHECK: [[VAR_8_:%.+]] = arith.addf [[VAR_7_]], [[VAR_cst_]] : vector<32xf32>
-// CHECK: [[VAR_9_:%.+]] = arith.divf [[VAR_cst_]], [[VAR_8_]] : vector<32xf32>
-// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_reshape_4_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_2_]]{{.}}){
+// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_cst_0_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
+// CHECK: [[VAR_7_:%.+]] = math.exp [[VAR_6_]] : vector<32xf32>
+// CHECK: [[VAR_8_:%.+]] = arith.addf [[VAR_7_]], [[VAR_cst_]] : vector<32xf32>
+// CHECK: [[VAR_9_:%.+]] = arith.divf [[VAR_cst_]], [[VAR_8_]] : vector<32xf32>
+// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_4_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: }
// CHECK: }
// CHECK: return [[VAR_view_]] : memref
// CHECK: }
@@ -1523,7 +1641,7 @@ func.func private @test_relu(%arg0 : tensor) -> tensor<*xf32> {
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
-// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
// CHECK-LABEL: func.func private @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -1541,14 +1659,16 @@ func.func private @test_relu(%arg0 : tensor) -> tensor<*xf32> {
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex>
-// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
-// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
-// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
-// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
-// CHECK: [[VAR_6_:%.+]] = arith.maxnumf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32>
-// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){
+// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_6_:%.+]] = arith.maxnumf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32>
+// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: }
// CHECK: }
// CHECK: return [[VAR_view_]] : memref
// CHECK: }
@@ -1556,6 +1676,7 @@ func.func private @test_relu(%arg0 : tensor) -> tensor<*xf32> {
// -----
+
func.func private @test_elu(%arg0 : tensor) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
@@ -1563,7 +1684,7 @@ func.func private @test_elu(%arg0 : tensor) -> tensor<*xf32> {
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
-// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
// CHECK-LABEL: func.func private @test_elu
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.000000e+00> : vector<32xf32>
@@ -1583,18 +1704,20 @@ func.func private @test_elu(%arg0 : tensor) -> tensor<*xf32> {
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex>
-// CHECK-DAG: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
-// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
-// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2]){
-// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
-// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
-// CHECK-DAG: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
-// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf olt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32>
-// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_cst_0_]] : vector<32xf32>
-// CHECK: [[VAR_9_:%.+]] = arith.mulf [[VAR_8_]], [[VAR_cst_]] : vector<32xf32>
-// CHECK: [[VAR_10_:%.+]] = arith.select [[VAR_7_]], [[VAR_9_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xi1>, vector<32xf32>
-// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2, [[VAR_2_]]{{.}}){
+// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
+// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf olt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32>
+// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_cst_0_]] : vector<32xf32>
+// CHECK: [[VAR_9_:%.+]] = arith.mulf [[VAR_8_]], [[VAR_cst_]] : vector<32xf32>
+// CHECK: [[VAR_10_:%.+]] = arith.select [[VAR_7_]], [[VAR_9_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xi1>, vector<32xf32>
+// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32>
+// CHECK: }
// CHECK: }
// CHECK: return [[VAR_view_]] : memref
// CHECK: }
@@ -1602,6 +1725,7 @@ func.func private @test_elu(%arg0 : tensor) -> tensor<*xf32> {
// -----
+
func.func private @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
@@ -1609,7 +1733,7 @@ func.func private @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> {
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
-// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
// CHECK-LABEL: func.func private @test_leakyrelu
// 
CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -1626,13 +1750,15 @@ func.func private @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: vector.store [[LOAD_VAR_reshape_MEM_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1640,6 +1766,7 @@ func.func private @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_selu(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1647,7 +1774,7 @@ func.func private @test_selu(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_selu // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.000000e+00> : vector<32xf32> @@ -1667,18 +1794,20 @@ func.func private @test_selu(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: 
krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_cst_0_]] : vector<32xf32> -// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[LOAD_VAR_reshape_MEM_]], [[VAR_8_]] : vector<32xi1>, vector<32xf32> -// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_cst_0_]] : vector<32xf32> +// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[LOAD_VAR_reshape_MEM_]], [[VAR_8_]] : vector<32xi1>, vector<32xf32> +// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: vector.store [[VAR_10_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1694,7 +1823,7 @@ func.func private @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_hardsigmoid // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.000000e+00> : vector<32xf32> @@ -1714,16 +1843,18 @@ func.func private @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 
[[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: [[VAR_7_:%.+]] = arith.maxnumf [[VAR_6_]], [[VAR_cst_1_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = arith.minnumf [[VAR_7_]], [[VAR_cst_0_]] : vector<32xf32> -// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: [[VAR_7_:%.+]] = arith.maxnumf [[VAR_6_]], [[VAR_cst_1_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = arith.minnumf [[VAR_7_]], [[VAR_cst_0_]] : vector<32xf32> +// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1731,6 +1862,7 @@ func.func private @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_reciprocal(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Reciprocal"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1738,7 +1870,7 @@ func.func private @test_reciprocal(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_reciprocal // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> @@ -1756,14 +1888,16 @@ func.func private @test_reciprocal(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} 
: memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = arith.divf [[VAR_cst_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.divf [[VAR_cst_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1771,6 +1905,7 @@ func.func private @test_reciprocal(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_softplus(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Softplus"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1778,7 +1913,7 @@ func.func private @test_softplus(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_softplus // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> @@ -1796,16 +1931,18 @@ func.func private @test_softplus(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: [[VAR_7_:%.+]] = arith.addf [[VAR_6_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = math.log [[VAR_7_]] : vector<32xf32> -// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], 
[[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.exp [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: [[VAR_7_:%.+]] = arith.addf [[VAR_6_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = math.log [[VAR_7_]] : vector<32xf32> +// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1813,6 +1950,7 @@ func.func private @test_softplus(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_softsign(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Softsign"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1820,7 +1958,7 @@ func.func private @test_softsign(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_softsign // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> @@ -1838,16 +1976,18 @@ func.func private @test_softsign(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: [[VAR_7_:%.+]] = arith.addf [[VAR_6_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_]], [[VAR_7_]] : vector<32xf32> -// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = 
krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: [[VAR_7_:%.+]] = arith.addf [[VAR_6_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_]], [[VAR_7_]] : vector<32xf32> +// CHECK: vector.store [[VAR_8_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1855,6 +1995,7 @@ func.func private @test_softsign(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_sqrt(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Sqrt"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1862,7 +2003,7 @@ func.func private @test_sqrt(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_sqrt // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -1879,14 +2020,16 @@ func.func private @test_sqrt(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.sqrt [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.sqrt [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1894,6 +2037,7 @@ func.func private @test_sqrt(%arg0 : 
tensor) -> tensor<*xf32> { // ----- + func.func private @test_sign_f(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Sign"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1901,7 +2045,7 @@ func.func private @test_sign_f(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_sign_f // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-1.000000e+00> : vector<32xf32> @@ -1921,17 +2065,19 @@ func.func private @test_sign_f(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_cst_0_]], [[VAR_cst_]] : vector<32xi1>, vector<32xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpf oeq, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> -// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_cst_1_]], [[VAR_7_]] : vector<32xi1>, vector<32xf32> -// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_cst_0_]], [[VAR_cst_]] : vector<32xi1>, vector<32xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpf oeq, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xf32> +// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_cst_1_]], [[VAR_7_]] : vector<32xi1>, vector<32xf32> +// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : 
memref // CHECK: } @@ -1939,6 +2085,7 @@ func.func private @test_sign_f(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_sign_i(%arg0 : tensor) -> tensor<*xi32> { %0 = "onnx.Sign"(%arg0) : (tensor) -> tensor<*xi32> "func.return"(%0) : (tensor<*xi32>) -> () @@ -1946,7 +2093,7 @@ func.func private @test_sign_i(%arg0 : tensor) -> tensor<*xi32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_sign_i // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-1> : vector<32xi32> @@ -1966,17 +2113,19 @@ func.func private @test_sign_i(%arg0 : tensor) -> tensor<*xi32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> -// CHECK: [[VAR_6_:%.+]] = arith.cmpi sgt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xi32> -// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_cst_0_]], [[VAR_cst_]] : vector<32xi1>, vector<32xi32> -// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi eq, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xi32> -// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_cst_1_]], [[VAR_7_]] : vector<32xi1>, vector<32xi32> -// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> +// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_2, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> +// CHECK: [[VAR_6_:%.+]] = arith.cmpi sgt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_cst_0_]], [[VAR_cst_]] : vector<32xi1>, vector<32xi32> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi eq, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_1_]] : vector<32xi32> +// CHECK: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_cst_1_]], [[VAR_7_]] : vector<32xi1>, vector<32xi32> +// CHECK: vector.store [[VAR_9_]], [[VAR_reshape_5_]]{{.}}[[VAR_4_]]{{.}} : memref, 
vector<32xi32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -1984,6 +2133,7 @@ func.func private @test_sign_i(%arg0 : tensor) -> tensor<*xi32> { // ----- + func.func private @test_abs_float(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Abs"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -1991,7 +2141,7 @@ func.func private @test_abs_float(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_abs_float // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -2008,14 +2158,16 @@ func.func private @test_abs_float(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -2023,6 +2175,7 @@ func.func private @test_abs_float(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_abs_int(%arg0 : tensor) -> tensor<*xi32> { %0 = "onnx.Abs"(%arg0) : (tensor) -> tensor<*xi32> "func.return"(%0) : (tensor<*xi32>) -> () @@ -2030,7 +2183,7 @@ func.func private @test_abs_int(%arg0 : tensor) -> tensor<*xi32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> 
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_abs_int // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -2047,14 +2200,16 @@ func.func private @test_abs_int(%arg0 : tensor) -> tensor<*xi32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> -// CHECK: [[VAR_6_:%.+]] = math.absi [[LOAD_VAR_reshape_MEM_]] : vector<32xi32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> +// CHECK: [[VAR_6_:%.+]] = math.absi [[LOAD_VAR_reshape_MEM_]] : vector<32xi32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xi32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -2062,6 +2217,7 @@ func.func private @test_abs_int(%arg0 : tensor) -> tensor<*xi32> { // ----- + func.func private @test_floor(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Floor"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -2069,7 +2225,7 @@ func.func private @test_floor(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_floor // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -2086,14 +2242,16 @@ func.func private @test_floor(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] 
= krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } @@ -2101,6 +2259,7 @@ func.func private @test_floor(%arg0 : tensor) -> tensor<*xf32> { // ----- + func.func private @test_ceil(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Ceil"(%arg0) : (tensor) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -2108,7 +2267,7 @@ func.func private @test_ceil(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func private @test_ceil // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : index @@ -2125,14 +2284,16 @@ func.func private @test_ceil(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = math.ceil [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> 
+// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = math.ceil [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_parallel_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_parallel_canonicalize_O3.mlir index 662641b15f..336fbde153 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_parallel_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_parallel_canonicalize_O3.mlir @@ -14,7 +14,7 @@ func.func @test_relu_parallel(%arg0 : tensor) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> // CHECK-LABEL: func.func @test_relu_parallel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32> @@ -32,15 +32,17 @@ func.func @test_relu_parallel(%arg0 : tensor) -> tensor<*xf32> { // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> -// CHECK: [[VAR_6_:%.+]] = arith.maxnumf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32> -// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_reshape_3_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) 
with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0, [[VAR_2_]]{{.}}){ +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: [[VAR_6_:%.+]] = arith.maxnumf [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_]] : vector<32xf32> +// CHECK: vector.store [[VAR_6_]], [[VAR_reshape_3_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<32xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize.mlir index 72184f4996..0bb543ded4 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize.mlir @@ -17,70 +17,63 @@ func.func @test_matmulinteger_per_tensor(%arg0: tensor<16x32xui8>, %arg1: tensor // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xui8> -// CHECK: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_10_:%.+]] = arith.extui [[VAR_9_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xi32> +// CHECK: [[VAR_11_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_11_]]#0, [[VAR_11_]]#1] : memref<16x32xui8> +// CHECK: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_14_:%.+]] = arith.extui [[VAR_13_]] : i8 to i32 +// CHECK: krnl.store [[VAR_14_]], [[RES_]]{{.}}[[VAR_11_]]#0, [[VAR_11_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 -// CHECK: [[VAR_10_1_:%.+]] = arith.extui [[VAR_9_1_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[VAR_7_1_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> -// CHECK-DAG: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_2_]]#0, 
[[LOOP_2_]]#1) with ([[LOOP_2_]]#0 -> [[I_3_:%.+]] = 0 to 16, [[LOOP_2_]]#1 -> [[I_4_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32> -// CHECK-DAG: [[VAR_9_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_2_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_9_1_]] : i32 -// CHECK: krnl.store [[VAR_10_2_]], [[RES_2_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32> +// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to 16, [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 32){ +// CHECK: [[VAR_11_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_11_1_]]#0, [[VAR_11_1_]]#1] : memref<16x32xi32> +// CHECK-DAG: [[VAR_13_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[VAR_14_1_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_13_1_]] : i32 +// CHECK: krnl.store [[VAR_14_1_]], [[RES_2_]]{{.}}[[VAR_11_1_]]#0, [[VAR_11_1_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> -// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_3_]]#0, [[LOOP_3_]]#1) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 32, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xui8> -// CHECK: [[VAR_9_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_3_:%.+]] = arith.extui [[VAR_9_2_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_3_]], [[RES_3_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1) with ([[LOOP_2_]]#0 -> [[I_4_:%.+]] = 0 to 32, [[LOOP_2_]]#1 -> [[I_5_:%.+]] = 0 to 64){ +// CHECK: [[VAR_11_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_11_2_]]#0, [[VAR_11_2_]]#1] : memref<32x64xui8> +// CHECK: [[VAR_13_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 +// CHECK: [[VAR_14_2_:%.+]] = arith.extui [[VAR_13_2_]] : i8 to i32 +// CHECK: krnl.store [[VAR_14_2_]], [[RES_3_]]{{.}}[[VAR_11_2_]]#0, [[VAR_11_2_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_7_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_4_:%.+]] = arith.extui [[VAR_9_3_]] : i8 to i32 -// CHECK: krnl.store 
[[VAR_10_4_]], [[RES_4_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: [[LOAD_PARAM_3_MEM_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_3_MEM_]] : ui8 to i8 +// CHECK: [[VAR_8_:%.+]] = arith.extui [[VAR_7_]] : i8 to i32 +// CHECK: krnl.store [[VAR_8_]], [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> -// CHECK-DAG: [[LOOP_5_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_8_:%.+]] = 0 to 32, [[LOOP_5_]]#1 -> [[I_9_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[VAR_7_5_]]#0, [[VAR_7_5_]]#1] : memref<32x64xi32> -// CHECK-DAG: [[VAR_9_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_5_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_1_]], [[VAR_9_3_]] : i32 -// CHECK: krnl.store [[VAR_10_5_]], [[RES_5_]]{{.}}[[VAR_7_5_]]#0, [[VAR_7_5_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_3_]]#0, [[LOOP_3_]]#1) with ([[LOOP_3_]]#0 -> [[I_6_:%.+]] = 0 to 32, [[LOOP_3_]]#1 -> [[I_7_:%.+]] = 0 to 64){ +// CHECK: [[VAR_11_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[VAR_11_3_]]#0, [[VAR_11_3_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[VAR_13_2_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[VAR_14_3_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_]], [[VAR_13_2_]] : i32 +// CHECK: krnl.store [[VAR_14_3_]], [[RES_5_]]{{.}}[[VAR_11_3_]]#0, [[VAR_11_3_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<16x64xi32> -// CHECK-DAG: [[LOOP_6_:%.+]]:3 = krnl.define_loops 3 -// CHECK: krnl.iterate([[LOOP_6_]]#0, [[LOOP_6_]]#1) with ([[LOOP_6_]]#0 -> [[I_10_:%.+]] = 0 to 16, [[LOOP_6_]]#1 -> [[I_11_:%.+]] = 0 to 64, [[LOOP_6_]]#2 -> [[I_12_:%.+]] = 0 to 32){ -// CHECK-DAG: [[VAR_7_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_6_]]#0, [[LOOP_6_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[IterResult:%.+]] = krnl.iterate([[LOOP_6_]]#2) with () iter_args([[IterArg:%.+]] = [[CST_0_]]) -> (i32){ -// CHECK: [[VAR_9_4_:%.+]] = krnl.get_induction_var_value([[LOOP_6_]]#2) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_10_5_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_7_6_]]#0, [[VAR_9_4_]]{{.}} : memref<16x32xi32> -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_9_4_]], [[VAR_7_6_]]#1] : memref<32x64xi32> -// CHECK: [[VAR_13_:%.+]] = arith.muli [[VAR_10_5_]], [[LOAD_RES_5_MEM_]] : i32 -// CHECK: [[VAR_14_:%.+]] = arith.addi [[IterArg]], [[VAR_13_]] : i32 -// CHECK: krnl.yield [[VAR_14_]] : i32 +// CHECK-DAG: [[LOOP_4_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_4_]]#0, [[LOOP_4_]]#1) with ([[LOOP_4_]]#0 -> [[I_8_:%.+]] = 0 to 16, [[LOOP_4_]]#1 -> [[I_9_:%.+]] = 0 to 64, [[LOOP_4_]]#2 -> [[I_10_:%.+]] = 0 to 32){ +// CHECK-DAG: [[VAR_11_4_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_4_]]#0, [[LOOP_4_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.iterate([[LOOP_4_]]#2) 
with () iter_args([[VAR_arg7_:%.+]] = [[CST_0_]]) -> (i32){ +// CHECK-DAG: [[VAR_13_3_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]#2) : (!krnl.loop) -> index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_3_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_11_4_]]#0, [[VAR_13_3_]]{{.}} : memref<16x32xi32> +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_13_3_]], [[VAR_11_4_]]#1] : memref<32x64xi32> +// CHECK: [[VAR_16_:%.+]] = arith.muli [[VAR_14_3_]], [[LOAD_RES_5_MEM_]] : i32 +// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_16_]] : i32 +// CHECK: krnl.yield [[VAR_17_]] : i32 // CHECK: } -// CHECK: krnl.store [[IterResult]], [[RES_6_]]{{.}}[[VAR_7_6_]]#0, [[VAR_7_6_]]#1] : memref<16x64xi32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_1_]], [[RES_6_]]{{.}}[[VAR_11_4_]]#0, [[VAR_11_4_]]#1] : memref<16x64xi32> // CHECK: } // CHECK: return [[RES_6_]] : memref<16x64xi32> // CHECK: } @@ -100,71 +93,68 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor< // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xui8> -// CHECK: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_10_:%.+]] = arith.extui [[VAR_9_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xi32> +// CHECK: [[VAR_9_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_9_]]#0, [[VAR_9_]]#1] : memref<16x32xui8> +// CHECK: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_12_:%.+]] = arith.extui [[VAR_11_]] : i8 to i32 +// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_9_]]#0, [[VAR_9_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<16xi32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 16){ -// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<16xui8> -// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 -// CHECK: [[VAR_10_1_:%.+]] = arith.extui [[VAR_9_1_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[VAR_7_1_]]{{.}} : memref<16xi32> +// CHECK: [[VAR_9_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_9_1_]]{{.}} : memref<16xui8> +// CHECK: [[VAR_11_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 +// CHECK: [[VAR_12_1_:%.+]] = arith.extui [[VAR_11_1_]] : i8 to i32 +// CHECK: krnl.store [[VAR_12_1_]], [[RES_1_]]{{.}}[[VAR_9_1_]]{{.}} : memref<16xi32> // CHECK: } // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_1_]] to offset: 
[0], sizes: [16, 1], strides: [1, 1] : memref<16xi32> to memref<16x1xi32> // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1) with ([[LOOP_2_]]#0 -> [[I_3_:%.+]] = 0 to 16, [[LOOP_2_]]#1 -> [[I_4_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32> -// CHECK-DAG: [[VAR_9_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_7_2_]]#0, [[CST_0_1_]]{{.}} : memref<16x1xi32> -// CHECK: [[VAR_10_2_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_9_1_]] : i32 -// CHECK: krnl.store [[VAR_10_2_]], [[RES_2_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32> +// CHECK: [[VAR_9_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_9_2_]]#0, [[VAR_9_2_]]#1] : memref<16x32xi32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_9_2_]]#0, [[CST_0_1_]]{{.}} : memref<16x1xi32> +// CHECK: [[VAR_12_2_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_11_1_]] : i32 +// CHECK: krnl.store [[VAR_12_2_]], [[RES_2_]]{{.}}[[VAR_9_2_]]#0, [[VAR_9_2_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> // CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_3_]]#0, [[LOOP_3_]]#1) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 32, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xui8> -// CHECK: [[VAR_9_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_3_:%.+]] = arith.extui [[VAR_9_2_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_3_]], [[RES_3_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xi32> +// CHECK: [[VAR_9_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_9_3_]]#0, [[VAR_9_3_]]#1] : memref<32x64xui8> +// CHECK: [[VAR_11_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 +// CHECK: [[VAR_12_3_:%.+]] = arith.extui [[VAR_11_2_]] : i8 to i32 +// CHECK: krnl.store [[VAR_12_3_]], [[RES_3_]]{{.}}[[VAR_9_3_]]#0, [[VAR_9_3_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_7_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_4_:%.+]] = arith.extui [[VAR_9_3_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_4_]], [[RES_4_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: 
[[LOAD_PARAM_3_MEM_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_3_MEM_]] : ui8 to i8 +// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32 +// CHECK: krnl.store [[VAR_6_]], [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> -// CHECK-DAG: [[LOOP_5_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_8_:%.+]] = 0 to 32, [[LOOP_5_]]#1 -> [[I_9_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[VAR_7_5_]]#0, [[VAR_7_5_]]#1] : memref<32x64xi32> -// CHECK-DAG: [[VAR_9_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_5_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_1_]], [[VAR_9_3_]] : i32 -// CHECK: krnl.store [[VAR_10_5_]], [[RES_5_]]{{.}}[[VAR_7_5_]]#0, [[VAR_7_5_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[LOOP_4_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_4_]]#0, [[LOOP_4_]]#1) with ([[LOOP_4_]]#0 -> [[I_7_:%.+]] = 0 to 32, [[LOOP_4_]]#1 -> [[I_8_:%.+]] = 0 to 64){ +// CHECK: [[VAR_9_4_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_4_]]#0, [[LOOP_4_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[VAR_9_4_]]#0, [[VAR_9_4_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[VAR_11_2_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[VAR_12_4_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_]], [[VAR_11_2_]] : i32 +// CHECK: krnl.store [[VAR_12_4_]], [[RES_5_]]{{.}}[[VAR_9_4_]]#0, [[VAR_9_4_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<16x64xi32> -// CHECK-DAG: [[LOOP_6_:%.+]]:3 = krnl.define_loops 3 -// CHECK: krnl.iterate([[LOOP_6_]]#0, [[LOOP_6_]]#1) with ([[LOOP_6_]]#0 -> [[I_10_:%.+]] = 0 to 16, [[LOOP_6_]]#1 -> [[I_11_:%.+]] = 0 to 64, [[LOOP_6_]]#2 -> [[I_12_:%.+]] = 0 to 32){ -// CHECK-DAG: [[VAR_7_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_6_]]#0, [[LOOP_6_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[IterResult:%.+]] = krnl.iterate([[LOOP_6_]]#2) with () iter_args([[IterArg:%.+]] = [[CST_0_]]) -> (i32){ -// CHECK: [[VAR_9_4_:%.+]] = krnl.get_induction_var_value([[LOOP_6_]]#2) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_10_5_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_7_6_]]#0, [[VAR_9_4_]]{{.}} : memref<16x32xi32> -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_9_4_]], [[VAR_7_6_]]#1] : memref<32x64xi32> -// CHECK: [[VAR_13_:%.+]] = arith.muli [[VAR_10_5_]], [[LOAD_RES_5_MEM_]] : i32 -// CHECK: [[VAR_14_:%.+]] = arith.addi [[IterArg]], [[VAR_13_]] : i32 -// CHECK: krnl.yield [[VAR_14_]] : i32 +// CHECK-DAG: [[LOOP_5_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_9_:%.+]] = 0 to 16, [[LOOP_5_]]#1 -> [[I_10_:%.+]] = 0 to 64, [[LOOP_5_]]#2 -> [[I_11_:%.+]] = 0 to 32){ +// CHECK-DAG: [[VAR_9_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_1_:%.+]] = krnl.iterate([[LOOP_5_]]#2) with () iter_args([[VAR_arg7_:%.+]] = [[CST_0_]]) -> (i32){ +// CHECK-DAG: [[VAR_11_3_:%.+]] = 
krnl.get_induction_var_value([[LOOP_5_]]#2) : (!krnl.loop) -> index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_4_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_9_5_]]#0, [[VAR_11_3_]]{{.}} : memref<16x32xi32> +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_11_3_]], [[VAR_9_5_]]#1] : memref<32x64xi32> +// CHECK: [[VAR_14_:%.+]] = arith.muli [[VAR_12_4_]], [[LOAD_RES_5_MEM_]] : i32 +// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_14_]] : i32 +// CHECK: krnl.yield [[VAR_15_]] : i32 // CHECK: } -// CHECK: krnl.store [[IterResult]], [[RES_6_]]{{.}}[[VAR_7_6_]]#0, [[VAR_7_6_]]#1] : memref<16x64xi32> +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_1_1_]], [[RES_6_]]{{.}}[[VAR_9_5_]]#0, [[VAR_9_5_]]#1] : memref<16x64xi32> // CHECK: } // CHECK: return [[RES_6_]] : memref<16x64xi32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir index 0922754334..29c76f24fc 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir @@ -1,8 +1,9 @@ // RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// ----- + // use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine // can also use -march=x86-64 instead. - // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. @@ -23,83 +24,79 @@ func.func @test_matmulinteger_per_tensor(%arg0: tensor<16x32xui8>, %arg1: tensor // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xui8> -// CHECK: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_10_:%.+]] = arith.extui [[VAR_9_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xi32> +// CHECK: [[VAR_9_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_9_]]#0, [[VAR_9_]]#1] : memref<16x32xui8> +// CHECK: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_12_:%.+]] = arith.extui [[VAR_11_]] : i8 to i32 +// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_9_]]#0, [[VAR_9_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast 
[[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 -// CHECK: [[VAR_10_1_:%.+]] = arith.extui [[VAR_9_1_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[VAR_7_1_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_512_]], [[RES_3_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[RES_]]([[RES_]]_2) : (memref<16x32xi32>, memref<1xindex>) -> memref<512xi32> // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_512_]], [[RES_4_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_2_]]([[RES_4_]]) : (memref<16x32xi32>, memref<1xindex>) -> memref<512xi32> -// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = 0 to 512){ -// CHECK: [[VAR_7_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_2_]]{{.}} : memref<512xi32>, vector<32xi32> -// CHECK-DAG: [[VAR_9_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_2_:%.+]] = vector.splat [[VAR_9_1_]] : vector<32xi32> -// CHECK: [[VAR_11_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_10_2_]] : vector<32xi32> -// CHECK: vector.store [[VAR_11_]], [[VAR_reshape_4_]]{{.}}[[VAR_7_2_]]{{.}} : memref<512xi32>, vector<32xi32> +// CHECK: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_2_]]([[RES_4_]]) : (memref<16x32xi32>, memref<1xindex>) -> memref<512xi32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 512){ +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[LOAD_PARAM_0_MEM_1_]]{{.}} : memref<512xi32>, vector<32xi32> +// CHECK-DAG: [[VAR_12_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[VAR_13_:%.+]] = vector.splat [[VAR_12_1_]] : vector<32xi32> +// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_11_1_]], [[VAR_13_]] : vector<32xi32> +// CHECK: vector.store [[VAR_14_]], [[VAR_reshape_4_]]{{.}}[[LOAD_PARAM_0_MEM_1_]]{{.}} : memref<512xi32>, vector<32xi32> +// CHECK: } // CHECK: } // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> -// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_3_]]#0, [[LOOP_3_]]#1) with ([[LOOP_3_]]#0 -> [[I_4_:%.+]] = 0 to 32, [[LOOP_3_]]#1 -> [[I_5_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: 
[[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xui8> -// CHECK: [[VAR_9_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_3_:%.+]] = arith.extui [[VAR_9_2_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_3_]], [[RES_5_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1) with ([[LOOP_2_]]#0 -> [[I_3_:%.+]] = 0 to 32, [[LOOP_2_]]#1 -> [[I_4_:%.+]] = 0 to 64){ +// CHECK: [[VAR_9_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_9_1_]]#0, [[VAR_9_1_]]#1] : memref<32x64xui8> +// CHECK: [[VAR_11_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 +// CHECK: [[VAR_12_2_:%.+]] = arith.extui [[VAR_11_2_]] : i8 to i32 +// CHECK: krnl.store [[VAR_12_2_]], [[RES_5_]]{{.}}[[VAR_9_1_]]#0, [[VAR_9_1_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_6_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_4_:%.+]] = arith.extui [[VAR_9_3_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_4_]], [[RES_6_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: [[LOAD_PARAM_3_MEM_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_3_MEM_]] : ui8 to i8 +// CHECK: [[VAR_7_:%.+]] = arith.extui [[VAR_6_]] : i8 to i32 +// CHECK: krnl.store [[VAR_7_]], [[RES_6_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_2048_]], [[RES_8_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[RES_5_]]([[RES_8_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_2048_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_7_]]([[RES_9_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> -// CHECK-DAG: [[LOOP_5_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_5_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_5_]] -> [[I_7_:%.+]] = 0 to 2048){ -// CHECK: [[VAR_7_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_1_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_7_5_]]{{.}} : memref<2048xi32>, vector<32xi32> -// CHECK-DAG: [[VAR_9_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_5_:%.+]] = vector.splat [[VAR_9_3_]] : vector<32xi32> -// CHECK: [[VAR_11_1_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_1_]], [[VAR_10_5_]] : vector<32xi32> 
-// CHECK: vector.store [[VAR_11_1_]], [[VAR_reshape_11_]]{{.}}[[VAR_7_5_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_7_]]([[RES_9_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_3_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to 2048){ +// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_11_2_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[LOAD_PARAM_0_MEM_1_1_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK-DAG: [[VAR_12_2_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[VAR_13_1_:%.+]] = vector.splat [[VAR_12_2_]] : vector<32xi32> +// CHECK: [[VAR_14_1_:%.+]] = arith.subi [[VAR_11_2_]], [[VAR_13_1_]] : vector<32xi32> +// CHECK: vector.store [[VAR_14_1_]], [[VAR_reshape_11_]]{{.}}[[LOAD_PARAM_0_MEM_1_1_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK: } // CHECK: } // CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<16x64xi32> // CHECK: krnl.memset [[RES_10_]], [[CST_0_]] : memref<16x64xi32> -// CHECK: [[LOOP_6_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_6_]]#0 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_6_]]#1 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_6_]]#2 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.permute([[BLOCK_TILE__2_]], [[BLOCK_IN__2_]], [[BLOCK_TILE__3_]], [[BLOCK_IN__3_]], [[BLOCK_TILE__4_]], [[BLOCK_IN__4_]]) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop -// CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__3_]], [[BLOCK_TILE__4_]]) with ([[LOOP_6_]]#0 -> [[I_8_:%.+]] = [[CST_0_1_]] to [[CST_16_]], [[LOOP_6_]]#1 -> [[I_9_:%.+]] = [[CST_0_1_]] to [[CST_64_]], [[LOOP_6_]]#2 -> [[I_10_:%.+]] = [[CST_0_1_]] to [[CST_32_]]){ -// CHECK: [[VAR_7_6_:%.+]]:3 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__3_]], [[BLOCK_TILE__4_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK: krnl.matmul [[RES_2_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_7_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_2_]]4{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, ([[BLOCK_IN__2_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_7_6_]]#0, [[VAR_7_6_]]#1, [[VAR_7_6_]]#2), ([[CST_16_]], [[CST_64_]], [[CST_32_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 8]} : memref<16x32xi32>, memref<32x64xi32>, memref<16x64xi32>, (!krnl.loop, !krnl.loop, !krnl.loop) +// CHECK: [[LOOP_4_:%.+]]:3 = krnl.define_loops 3 +// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_4_]]#0 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_4_]]#1 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_4_]]#2 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.permute([[BLOCK_TILE__2_]], [[BLOCK_IN__2_]], [[BLOCK_TILE__2_]]_13, [[BLOCK_IN__2_]]_14, [[BLOCK_TILE__2_]]_15, [[BLOCK_IN__2_]]_16) [0, 
3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__2_]]_13, [[BLOCK_TILE__2_]]_15) with ([[LOOP_4_]]#0 -> [[I_6_:%.+]] = [[CST_0_1_]] to [[CST_16_]], [[LOOP_4_]]#1 -> [[I_7_:%.+]] = [[CST_0_1_]] to [[CST_64_]], [[LOOP_4_]]#2 -> [[I_8_:%.+]] = [[CST_0_1_]] to [[CST_32_]]){ +// CHECK: [[VAR_9_2_:%.+]]:3 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__2_]]_13, [[BLOCK_TILE__2_]]_15) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: krnl.matmul [[RES_2_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_7_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_2_]]2{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, ([[BLOCK_IN__2_]], [[BLOCK_IN__2_]]_14, [[BLOCK_IN__2_]]_16), ([[VAR_9_2_]]#0, [[VAR_9_2_]]#1, [[VAR_9_2_]]#2), ([[CST_16_]], [[CST_64_]], [[CST_32_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 8]} : memref<16x32xi32>, memref<32x64xi32>, memref<16x64xi32>, (!krnl.loop, !krnl.loop, !krnl.loop) // CHECK: } // CHECK: return [[RES_10_]] : memref<16x64xi32> // CHECK: } @@ -107,6 +104,7 @@ func.func @test_matmulinteger_per_tensor(%arg0: tensor<16x32xui8>, %arg1: tensor // ----- + func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor<32x64xui8>, %arg2: tensor<16xui8>, %arg3: tensor<1xui8>) -> tensor<16x64xi32> { %0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<16x32xui8>, tensor<32x64xui8>, tensor<16xui8>, tensor<1xui8>) -> tensor<16x64xi32> return %0 : tensor<16x64xi32> @@ -123,67 +121,69 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor< // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xui8> -// CHECK: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_10_:%.+]] = arith.extui [[VAR_9_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref<16x32xi32> +// CHECK: [[VAR_8_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1] : memref<16x32xui8> +// CHECK: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_11_:%.+]] = arith.extui [[VAR_10_]] : i8 to i32 +// CHECK: krnl.store [[VAR_11_]], [[RES_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1] : memref<16x32xi32> // CHECK: } // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<16xi32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 16){ -// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<16xui8> -// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 -// CHECK: [[VAR_10_1_:%.+]] = arith.extui [[VAR_9_1_]] : i8 to i32 -// CHECK: 
krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[VAR_7_1_]]{{.}} : memref<16xi32> +// CHECK: [[VAR_8_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_8_1_]]{{.}} : memref<16xui8> +// CHECK: [[VAR_10_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_]] : ui8 to i8 +// CHECK: [[VAR_11_1_:%.+]] = arith.extui [[VAR_10_1_]] : i8 to i32 +// CHECK: krnl.store [[VAR_11_1_]], [[RES_1_]]{{.}}[[VAR_8_1_]]{{.}} : memref<16xi32> // CHECK: } // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_1_]] to offset: [0], sizes: [16, 1], strides: [1, 1] : memref<16xi32> to memref<16x1xi32> // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<16x32xi32> -// CHECK-DAG: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]]#1 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_2_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_2_]]#0 -> [[I_3_:%.+]] = 0 to 16, [[LOOP_2_]]#1 -> [[I_4_:%.+]] = 0 to 32){ -// CHECK: [[VAR_7_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[RES_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32>, vector<32xi32> -// CHECK-DAG: [[VAR_9_1_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_7_2_]]#0, [[CST_0_1_]]{{.}} : memref<16x1xi32> -// CHECK: [[VAR_10_2_:%.+]] = vector.splat [[VAR_9_1_]] : vector<32xi32> -// CHECK: [[VAR_11_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_]], [[VAR_10_2_]] : vector<32xi32> -// CHECK: vector.store [[VAR_11_]], [[RES_2_]]{{.}}[[VAR_7_2_]]#0, [[VAR_7_2_]]#1] : memref<16x32xi32>, vector<32xi32> +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = 0 to 16){ +// CHECK-DAG: [[VAR_8_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_3_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_3_]] -> [[I_4_:%.+]] = 0 to 32){ +// CHECK: [[VAR_10_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[RES_]]{{.}}[[VAR_8_2_]], [[VAR_10_2_]]{{.}} : memref<16x32xi32>, vector<32xi32> +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = krnl.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_2_]], [[CST_0_1_]]{{.}} : memref<16x1xi32> +// CHECK: [[VAR_13_:%.+]] = vector.splat [[LOAD_VAR_reinterpret_cast_MEM_]] : vector<32xi32> +// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_11_1_]], [[VAR_13_]] : vector<32xi32> +// CHECK: vector.store [[VAR_14_]], [[RES_2_]]{{.}}[[VAR_8_2_]], [[VAR_10_2_]]{{.}} : memref<16x32xi32>, vector<32xi32> +// CHECK: } // CHECK: } // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> -// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_3_]]#0, [[LOOP_3_]]#1) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 32, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ -// CHECK: [[VAR_7_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[LOOP_3_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xui8> -// 
CHECK: [[VAR_9_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_3_:%.+]] = arith.extui [[VAR_9_2_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_3_]], [[RES_3_]]{{.}}[[VAR_7_3_]]#0, [[VAR_7_3_]]#1] : memref<32x64xi32> +// CHECK-DAG: [[LOOP_4_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_4_]]#0, [[LOOP_4_]]#1) with ([[LOOP_4_]]#0 -> [[I_5_:%.+]] = 0 to 32, [[LOOP_4_]]#1 -> [[I_6_:%.+]] = 0 to 64){ +// CHECK: [[VAR_8_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_4_]]#0, [[LOOP_4_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOOP_3_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1] : memref<32x64xui8> +// CHECK: [[VAR_10_3_:%.+]] = builtin.unrealized_conversion_cast [[LOOP_3_]] : ui8 to i8 +// CHECK: [[VAR_11_2_:%.+]] = arith.extui [[VAR_10_3_]] : i8 to i32 +// CHECK: krnl.store [[VAR_11_2_]], [[RES_3_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1] : memref<32x64xi32> // CHECK: } // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1xi32> -// CHECK-DAG: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_7_:%.+]] = 0 to 1){ -// CHECK: [[VAR_7_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xui8> -// CHECK: [[VAR_9_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_1_1_]] : ui8 to i8 -// CHECK: [[VAR_10_4_:%.+]] = arith.extui [[VAR_9_3_]] : i8 to i32 -// CHECK: krnl.store [[VAR_10_4_]], [[RES_4_]]{{.}}[[VAR_7_4_]]{{.}} : memref<1xi32> -// CHECK: } +// CHECK-DAG: [[LOAD_PARAM_3_MEM_:%.+]] = krnl.load [[PARAM_3_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xui8> +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_3_MEM_]] : ui8 to i8 +// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32 +// CHECK: krnl.store [[VAR_6_]], [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32x64xi32> // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_2048_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[RES_3_]]([[RES_6_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_2048_]], [[RES_7_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_5_]]([[RES_7_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> -// CHECK-DAG: [[LOOP_5_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_5_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_5_]] -> [[I_8_:%.+]] = 0 to 2048){ -// CHECK: [[VAR_7_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_1_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_5_]]{{.}} : memref<2048xi32>, vector<32xi32> -// CHECK-DAG: [[VAR_9_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> -// CHECK: [[VAR_10_5_:%.+]] = vector.splat [[VAR_9_3_]] : vector<32xi32> -// CHECK: [[VAR_11_1_:%.+]] = arith.subi [[LOAD_PARAM_0_MEM_1_1_1_]], [[VAR_10_5_]] : vector<32xi32> -// CHECK: vector.store [[VAR_11_1_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_5_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK: 
[[VAR_reshape_7_:%.+]] = memref.reshape [[RES_5_]]([[RES_7_]]) : (memref<32x64xi32>, memref<1xindex>) -> memref<2048xi32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_5_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_5_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_5_]] -> [[I_7_:%.+]] = 0 to 2048){ +// CHECK: [[LOOP_3_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[LOOP_3_1_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK-DAG: [[VAR_11_2_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.splat [[VAR_11_2_]] : vector<32xi32> +// CHECK: [[VAR_13_1_:%.+]] = arith.subi [[VAR_10_3_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]] : vector<32xi32> +// CHECK: vector.store [[VAR_13_1_]], [[VAR_reshape_7_]]{{.}}[[LOOP_3_1_]]{{.}} : memref<2048xi32>, vector<32xi32> +// CHECK: } // CHECK: } // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<16x64xi32> // CHECK: krnl.memset [[RES_8_]], [[CST_0_]] : memref<16x64xi32> @@ -191,11 +191,12 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor< // CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_6_]]#0 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_6_]]#1 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_6_]]#2 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.permute([[BLOCK_TILE__2_]], [[BLOCK_IN__2_]], [[BLOCK_TILE__3_]], [[BLOCK_IN__3_]], [[BLOCK_TILE__4_]], [[BLOCK_IN__4_]]) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop -// CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__3_]], [[BLOCK_TILE__4_]]) with ([[LOOP_6_]]#0 -> [[I_9_:%.+]] = [[CST_0_1_]] to [[CST_16_]], [[LOOP_6_]]#1 -> [[I_10_:%.+]] = [[CST_0_1_]] to [[CST_64_]], [[LOOP_6_]]#2 -> [[I_11_:%.+]] = [[CST_0_1_]] to [[CST_32_]]){ -// CHECK: [[VAR_7_6_:%.+]]:3 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__3_]], [[BLOCK_TILE__4_]]) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK: krnl.matmul [[RES_2_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_5_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_2_]]0{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, ([[BLOCK_IN__2_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_7_6_]]#0, [[VAR_7_6_]]#1, [[VAR_7_6_]]#2), ([[CST_16_]], [[CST_64_]], [[CST_32_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 8]} : memref<16x32xi32>, memref<32x64xi32>, memref<16x64xi32>, (!krnl.loop, !krnl.loop, !krnl.loop) +// CHECK: krnl.permute([[BLOCK_TILE__2_]], [[BLOCK_IN__2_]], [[BLOCK_TILE__2_]]_9, [[BLOCK_IN__2_]]_10, [[BLOCK_TILE__2_]]_11, [[BLOCK_IN__2_]]_12) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__2_]]_9, [[BLOCK_TILE__2_]]_11) with ([[LOOP_6_]]#0 -> [[I_8_:%.+]] = [[CST_0_1_]] to [[CST_16_]], [[LOOP_6_]]#1 -> [[I_9_:%.+]] = [[CST_0_1_]] to [[CST_64_]], [[LOOP_6_]]#2 -> [[I_10_:%.+]] = [[CST_0_1_]] to [[CST_32_]]){ +// CHECK: [[VAR_8_4_:%.+]]:3 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__2_]]_9, [[BLOCK_TILE__2_]]_11) : (!krnl.loop, 
!krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: krnl.matmul [[RES_2_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_5_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, [[RES_8_]]{{.}}[[CST_0_1_]], [[CST_0_1_]]{{.}}, ([[BLOCK_IN__2_]], [[BLOCK_IN__2_]]_10, [[BLOCK_IN__2_]]_12), ([[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_8_4_]]#2), ([[CST_16_]], [[CST_64_]], [[CST_32_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 8]} : memref<16x32xi32>, memref<32x64xi32>, memref<16x64xi32>, (!krnl.loop, !krnl.loop, !krnl.loop) // CHECK: } // CHECK: return [[RES_8_]] : memref<16x64xi32> // CHECK: } } + diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir index 0b7e3c01b4..5ef0892d8b 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir @@ -856,6 +856,102 @@ func.func private @test_reducemax_v13_bis(%arg0 : tensor<1028x256xf32>) -> tenso // ----- +func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<*xf32> { + %0 ="onnx.ReduceMaxV13"(%arg0) {axes=[-1], keepdims = 0 : si64} : (tensor<7x8xf32>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (-d0 + 3)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (d0 + 3)> +// CHECK-LABEL: func.func private @test_reducemax_v13_small +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<7x8xf32>) -> memref<7xf32> { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<4xf32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_7_:%.+]] = arith.constant 7 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<7xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 7){ +// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) +// CHECK: [[VAR_3_:%.+]] = arith.cmpi slt, [[VAR_2_]], [[CST_0_]] : index +// CHECK: scf.if [[VAR_3_]] { +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_1_]] to [[CST_7_]] step [[CST_1_]] { +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_8_]]){ +// CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[I_1_]], [[VAR_7_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// 
CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK: vector.store [[VAR_10_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_6_:%.+]] = vector.reduction <maxnumf>, [[LOAD_RES_1_MEM_1_]] : vector<4xf32> into f32 +// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[I_1_]]{{.}} : memref<7xf32> +// CHECK: } +// CHECK: } else { +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_8_]]){ +// CHECK: [[VAR_18_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_1_]]8] : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_21_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK: vector.store [[VAR_21_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_22_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_25_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_26_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_30_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_33_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// 
CHECK: vector.store [[VAR_33_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_8_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_9_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_1_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_10_]], [[VAR_10_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_12_]], [[VAR_13_]] : vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.maxnumf [[VAR_15_]], [[VAR_16_]] : vector<4xf32> +// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<7xf32>, vector<4xf32> +// CHECK: } +// CHECK: } +// CHECK: return [[RES_]] : memref<7xf32> +// CHECK: } +} + +// ----- + + func.func private @test_reducemax_int_v13(%arg0 : tensor<128x256x768xi32>) -> tensor<*xi32> { %0 = "onnx.ReduceMaxV13"(%arg0) {axes = [-1], keepdims = 0 : si64, onnx_node_name = "ReduceMean_32"} : (tensor<128x256x768xi32>) -> tensor<*xi32> "func.return"(%0) : (tensor<*xi32>) -> () @@ -863,26 +959,26 @@ func.func private @test_reducemax_int_v13(%arg0 : tensor<128x256x768xi32>) -> te // mlir2FileCheck.py // CHECK-LABEL: func.func private @test_reducemax_int_v13 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<128x256x768xi32>) -> memref<128x256xi32> { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-2147483648> : vector<16xi32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-2147483648> : vector<32xi32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<128x256xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 128, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<1x16xi32> -// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x16xi32>, vector<16xi32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<1x32xi32> +// CHECK: vector.store [[VAR_cst_]], 
[[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<128x256x768xi32>, vector<16xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x16xi32>, vector<16xi32> -// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<16xi32> -// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x16xi32>, vector<16xi32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<128x256x768xi32>, vector<32xi32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> +// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<32xi32> +// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> // CHECK: } -// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x16xi32>, vector<16xi32> -// CHECK: [[VAR_4_:%.+]] = vector.reduction <maxsi>, [[LOAD_RES_1_MEM_1_]] : vector<16xi32> into i32 +// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> +// CHECK: [[VAR_4_:%.+]] = vector.reduction <maxsi>, [[LOAD_RES_1_MEM_1_]] : vector<32xi32> into i32 // CHECK: krnl.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<128x256xi32> // CHECK: } // CHECK: return [[RES_]] : memref<128x256xi32> diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir index c6d720c24c..55dbdb1942 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir @@ -10,14 +10,15 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor, tensor, tensor // mlir2FileCheck.py -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 2)> // CHECK-LABEL: func.func @test_dynamic_quantize_linear // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> (memref, memref, memref) { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 // CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 // CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 
: f32 // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index @@ -26,70 +27,73 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref -// CHECK: krnl.store [[CST_2_dot_550000_]], [[RES_3_]][] : memref -// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref -// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_4_]][] : memref -// CHECK: [[RES_5_:%.+]] = memref.alloc() : memref -// CHECK: krnl.memset [[RES_5_]], [[CST_0_1_]] : memref +// CHECK: krnl.memset [[RES_3_]], [[CST_0_1_]] : memref // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 -// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ // CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref -// CHECK: [[VAR_34_:%.+]] = arith.maxnumf [[LOAD_RES_5_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_34_]], [[RES_5_]][] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref // CHECK: } -// CHECK: [[RES_6_:%.+]] = memref.alloc() : memref -// CHECK: krnl.memset [[RES_6_]], [[CST_0_]] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 -// CHECK-DAG: [[VAR_dim_13_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_13_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ +// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ // CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref -// CHECK-DAG: [[LOAD_RES_5_MEM_1_:%.+]] = krnl.load [[RES_6_]][] : memref -// CHECK: [[VAR_34_1_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_34_1_]], [[RES_6_]][] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref // CHECK: } -// CHECK-DAG: [[LOAD_RES_5_MEM_2_:%.+]] = krnl.load [[RES_5_]][] : memref -// CHECK-DAG: 
[[LOAD_RES_6_MEM_:%.+]] = krnl.load [[RES_6_]][] : memref -// CHECK: [[VAR_4_:%.+]] = arith.cmpf ogt, [[LOAD_RES_5_MEM_2_]], [[CST_0_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.select [[VAR_4_]], [[LOAD_RES_5_MEM_2_]], [[CST_0_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpf olt, [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_5_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_9_:%.+]] = arith.divf [[VAR_8_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_9_]], [[RES_1_]][] : memref -// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_7_]], [[VAR_9_]] : f32 -// CHECK: [[VAR_11_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_10_]] : f32 -// CHECK: [[VAR_12_:%.+]] = arith.maxnumf [[VAR_11_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_13_:%.+]] = arith.minnumf [[VAR_12_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.subf [[VAR_13_]], [[VAR_14_]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf ogt, [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.addf [[VAR_14_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = math.floor [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_11_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf ogt, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_16_]], [[VAR_17_]], [[VAR_14_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.mulf [[VAR_14_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_20_:%.+]] = math.floor [[VAR_19_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.mulf [[VAR_20_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.subf [[VAR_14_]], [[VAR_21_]] : f32 -// CHECK-DAG: [[VAR_23_:%.+]] = arith.cmpf oeq, [[VAR_22_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.addf [[VAR_14_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.mulf [[VAR_12_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = math.floor [[VAR_17_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.mulf [[VAR_18_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.subf [[VAR_12_]], [[VAR_19_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_20_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = 
arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_25_:%.+]] = arith.select [[VAR_23_]], [[VAR_24_]], [[VAR_14_]] : f32 -// CHECK-DAG: [[VAR_26_:%.+]] = arith.cmpf oeq, [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_27_:%.+]] = arith.select [[VAR_26_]], [[VAR_25_]], [[VAR_18_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.fptoui [[VAR_27_]] : f32 to i8 -// CHECK: [[VAR_29_:%.+]] = builtin.unrealized_conversion_cast [[VAR_28_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_29_]], [[RES_2_]][] : memref -// CHECK: [[LOOP_2_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1) with ([[LOOP_2_]]#0 -> [[I_4_:%.+]] = 0 to [[MAP_0_]]([[VAR_dim_]]), [[LOOP_2_]]#1 -> [[I_5_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD_PARAM_0_MEM_2_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_2_]]#0, [[VAR_31_2_]]#1] : memref -// CHECK: [[LOAD_RES_5_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_2_]], [[VAR_9_]] : f32 -// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_5_MEM_1_]] : f32 -// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_5_MEM_1_]], [[VAR_34_2_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.select [[VAR_24_]], [[VAR_23_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i8 +// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_28_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_28_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32 // CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 // CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs @@ -104,12 +108,12 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor +// CHECK: krnl.store [[VAR_52_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, 
memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir new file mode 100644 index 0000000000..d7180c3e2c --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -0,0 +1,408 @@ +// RUN: onnx-mlir-opt -O3 -mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// ----- + + +func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> (tensor<256x16xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<256x16xf32>) -> (tensor<256x16xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<256x16xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_4096_:%.+]] = arith.constant 4096 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<256x16xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_3_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK: vector.store [[VAR_cst_5_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK: vector.store [[VAR_cst_4_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = [[CST_0_]] to 
[[CST_4096_]]){ +// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_35_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[LOAD_RES_7_MEM_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[VAR_37_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOAD_RES_5_MEM_1_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = vector.reduction , [[LOAD_RES_5_MEM_1_]] : vector<32xf32> into f32 +// CHECK-DAG: [[LOAD_RES_7_MEM_1_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_7_MEM_1_]] : vector<32xf32> into f32 +// CHECK: krnl.store [[VAR_2_]], [[RES_4_]][] : memref +// CHECK: krnl.store [[VAR_4_]], [[RES_6_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = krnl.load [[RES_6_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_8_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = 
arith.fptoui [[VAR_28_]] : f32 to i8 +// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_8_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ +// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_5_MEM_2_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> +// CHECK: [[VAR_35_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_1_]], [[LOAD_RES_5_MEM_2_]] : vector<8xf32> +// CHECK: [[LOAD_RES_7_MEM_2_:%.+]] = math.floor [[VAR_35_1_]] : vector<8xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[VAR_35_1_]], [[LOAD_RES_7_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_]], [[VAR_39_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_7_MEM_2_]], [[VAR_43_]] : vector<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> +// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_55_]], 
[[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32>) -> (tensor<255x17xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<255x17xf32>) -> (tensor<255x17xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<255x17xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_4335_:%.+]] = arith.constant 4335 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<255x17xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_3_]][0] : memref<1xindex> +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref +// CHECK: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 255, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 17){ +// CHECK: [[VAR_30_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_30_]]#0, [[VAR_30_]]#1] : memref<255x17xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK: [[VAR_33_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_33_]], [[RES_4_]][] : memref +// CHECK: } +// CHECK: [[RES_5_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_5_]], [[CST_0_]] : memref +// CHECK: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to 255, [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 17){ +// CHECK: [[VAR_30_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_30_1_]]#0, [[VAR_30_1_]]#1] : memref<255x17xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK: [[VAR_33_1_:%.+]] = arith.maxnumf 
[[LOAD_RES_4_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_33_1_]], [[RES_5_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = math.floor [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_11_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf ogt, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.mulf [[VAR_12_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = math.floor [[VAR_17_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.mulf [[VAR_18_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.subf [[VAR_12_]], [[VAR_19_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_20_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.select [[VAR_24_]], [[VAR_23_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i8 +// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref +// CHECK: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_6_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_7_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_18_:%.+]] = memref.reshape [[RES_]]([[RES_]]_17) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4328){ +// CHECK: [[VAR_30_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_30_2_]]{{.}} : memref<4335xf32>, vector<8xf32> +// CHECK-DAG: 
[[LOAD_RES_4_MEM_1_:%.+]] = vector.splat [[VAR_7_]] : vector<8xf32> +// CHECK: [[VAR_33_2_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[LOAD_RES_4_MEM_1_]] : vector<8xf32> +// CHECK: [[VAR_34_:%.+]] = math.floor [[VAR_33_2_]] : vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.subf [[VAR_33_2_]], [[VAR_34_]] : vector<8xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : vector<8xf32> +// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_]], [[VAR_41_]] : vector<8xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = vector.splat [[VAR_25_]] : vector<8xf32> +// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_47_]], [[VAR_48_]] : vector<8xf32> +// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : vector<8xf32> to vector<8xi8> +// CHECK: [[VAR_53_:%.+]] = builtin.unrealized_conversion_cast [[VAR_52_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_53_]], [[VAR_reshape_18_]]{{.}}[[VAR_30_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK: } +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4328 to 4335){ +// CHECK: [[VAR_30_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_30_3_]]{{.}} : memref<4335xf32> +// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_1_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_33_3_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK: [[VAR_34_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_33_3_]] : f32 +// CHECK-DAG: [[VAR_35_1_:%.+]] = arith.cmpf ogt, [[VAR_34_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_36_1_:%.+]] = arith.addf [[VAR_33_3_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_37_1_:%.+]] = arith.select [[VAR_35_1_]], [[VAR_36_1_]], [[VAR_33_3_]] : f32 +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.mulf [[VAR_33_3_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : f32 +// CHECK: [[VAR_40_1_:%.+]] = arith.mulf [[VAR_39_1_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_41_1_:%.+]] = arith.subf [[VAR_33_3_]], [[VAR_40_1_]] : f32 +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.cmpf oeq, [[VAR_41_1_]], [[CST_1_dot_000000_]] 
: f32 +// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.addf [[VAR_33_3_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_44_1_:%.+]] = arith.select [[VAR_42_1_]], [[VAR_43_1_]], [[VAR_33_3_]] : f32 +// CHECK-DAG: [[VAR_45_1_:%.+]] = arith.cmpf oeq, [[VAR_34_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_46_1_:%.+]] = arith.select [[VAR_45_1_]], [[VAR_44_1_]], [[VAR_37_1_]] : f32 +// CHECK: [[VAR_47_1_:%.+]] = arith.addf [[VAR_46_1_]], [[VAR_25_]] : f32 +// CHECK: [[VAR_48_1_:%.+]] = arith.maxnumf [[VAR_47_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_49_1_:%.+]] = arith.minnumf [[VAR_48_1_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_50_1_:%.+]] = arith.fptoui [[VAR_49_1_]] : f32 to i8 +// CHECK: [[VAR_51_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_50_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_51_1_]], [[VAR_reshape_18_]]{{.}}[[VAR_30_3_]]{{.}} : memref<4335xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32>) -> (tensor<1x8xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<1x8xf32>) -> (tensor<1x8xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<1x8xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_reduced_simd_only +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x8xf32>) -> (memref<1x8xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x8xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_3_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK: vector.store [[VAR_cst_5_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK: vector.store [[VAR_cst_4_]], 
[[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = [[CST_0_]] to [[CST_8_]]){ +// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK: vector.store [[VAR_35_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[LOAD_RES_7_MEM_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_37_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: } +// CHECK: [[LOAD_RES_5_MEM_1_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = vector.reduction , [[LOAD_RES_5_MEM_1_]] : vector<8xf32> into f32 +// CHECK-DAG: [[LOAD_RES_7_MEM_1_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_7_MEM_1_]] : vector<8xf32> into f32 +// CHECK: krnl.store [[VAR_2_]], [[RES_4_]][] : memref +// CHECK: krnl.store [[VAR_4_]], [[RES_6_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = krnl.load [[RES_6_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_8_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive 
DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 +// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ +// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_5_MEM_2_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> +// CHECK: [[VAR_35_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_1_]], [[LOAD_RES_5_MEM_2_]] : vector<8xf32> +// CHECK: [[LOAD_RES_7_MEM_2_:%.+]] = math.floor [[VAR_35_1_]] : vector<8xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[VAR_35_1_]], [[LOAD_RES_7_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_]], [[VAR_39_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_7_MEM_2_]], [[VAR_43_]] : vector<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = 
arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> +// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref +// CHECK: } +} + diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir index 24b31ec16e..65c77c702d 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir @@ -3,10 +3,14 @@ // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. +// ----- + + func.func @test_quantize_linear(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> return %0 : tensor<6xui8> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_quantize_linear // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -41,14 +45,13 @@ func.func @test_quantize_linear(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: // CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 // CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.cmpf olt, [[VAR_22_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.select [[VAR_23_]], [[CST_0_dot_000000_]], [[VAR_22_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.cmpf olt, [[VAR_24_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_26_:%.+]] = arith.select [[VAR_25_]], [[VAR_24_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_27_:%.+]] = arith.fptoui [[VAR_26_]] : f32 to i8 -// CHECK: [[VAR_28_:%.+]] = builtin.unrealized_conversion_cast [[VAR_27_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_28_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> +// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.fptoui [[VAR_24_]] : f32 to i8 +// CHECK: [[VAR_26_:%.+]] = builtin.unrealized_conversion_cast [[VAR_25_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_26_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } } + diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Elementwise.mlir index 733bd440e7..e7b5fe9f1b 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Elementwise.mlir @@ -59,7 +59,7 @@ func.func @test_elu(%arg0 : tensor<20x40xf32>) -> tensor<20x40xf32> { // CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.exponential [[PARAM_0_]] : tensor<20x40xf32> // CHECK: [[VAR_4_:%.+]] = stablehlo.subtract [[VAR_3_]], [[VAR_2_]] : 
tensor<20x40xf32> // CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.multiply [[VAR_1_]], [[VAR_4_]] : tensor<20x40xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.compare GE, [[PARAM_0_]], [[VAR_0_]], NOTYPE : (tensor<20x40xf32>, tensor<20x40xf32>) -> tensor<20x40xi1> +// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.compare GE, [[PARAM_0_]], [[VAR_0_]] : (tensor<20x40xf32>, tensor<20x40xf32>) -> tensor<20x40xi1> // CHECK: [[VAR_7_:%.+]] = stablehlo.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : tensor<20x40xi1>, tensor<20x40xf32> // CHECK: return [[VAR_7_]] : tensor<20x40xf32> // CHECK: } @@ -287,7 +287,7 @@ func.func @test_less(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tens // CHECK: [[VAR_0_:%.+]] = shape.const_shape [3, 4, 5] : tensor<3xindex> // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_0_]], dims = [0, 1, 2] : (tensor<3x4x5xf32>, tensor<3xindex>) -> tensor<3x4x5xf32> // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[VAR_0_]], dims = [0, 1, 2] : (tensor<3x4x5xf32>, tensor<3xindex>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_1_]], [[VAR_2_]], NOTYPE : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_1_]], [[VAR_2_]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> // CHECK: return [[VAR_3_]] : tensor<3x4x5xi1> // CHECK: } @@ -305,7 +305,7 @@ func.func @test_binary_elementwise_op_template_unknown_dims(%arg0: tensor, tensor<3xindex> -> tensor<3xindex> // CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_2_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[VAR_2_]], dims = [0, 1, 2] : (tensor<1x?x1xf32>, tensor<3xindex>) -> tensor -// CHECK: [[VAR_5_:%.+]] = stablehlo.compare LT, [[VAR_3_]], [[VAR_4_]], NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_5_:%.+]] = stablehlo.compare LT, [[VAR_3_]], [[VAR_4_]] : (tensor, tensor) -> tensor // CHECK: return [[VAR_5_]] : tensor // CHECK: } @@ -323,7 +323,7 @@ func.func @test_less_unknown_dims_2(%arg0: tensor, %arg1: tensor, tensor<3xindex> -> tensor<3xindex> // CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_2_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[VAR_2_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK: [[VAR_5_:%.+]] = stablehlo.compare LT, [[VAR_3_]], [[VAR_4_]], NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_5_:%.+]] = stablehlo.compare LT, [[VAR_3_]], [[VAR_4_]] : (tensor, tensor) -> tensor // CHECK: return [[VAR_5_]] : tensor // CHECK: } @@ -447,7 +447,7 @@ func.func @test_leakyrelu_dynamic(%arg0 : tensor) -> tensor // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.multiply [[PARAM_0_]], [[VAR_3_]] : tensor // CHECK-DAG: [[VAR_5_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<2xindex> // CHECK: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_0_]], [[VAR_5_]], dims = [] : (tensor, tensor<2xindex>) -> tensor -// CHECK: [[VAR_7_:%.+]] = stablehlo.compare GT, [[PARAM_0_]], [[VAR_6_]], NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_7_:%.+]] = stablehlo.compare GT, [[PARAM_0_]], [[VAR_6_]] : (tensor, tensor) -> tensor // CHECK: [[VAR_8_:%.+]] = stablehlo.select [[VAR_7_]], [[PARAM_0_]], [[VAR_4_]] : tensor, tensor // CHECK: return [[VAR_8_]] : tensor // CHECK: } @@ -469,7 
+469,7 @@ func.func @test_prelu_dynamic(%arg0 : tensor, %arg1: tensor<10x1 // CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.multiply [[VAR_4_]], [[VAR_5_]] : tensor // CHECK-DAG: [[VAR_7_:%.+]] = shape.shape_of [[VAR_4_]] : tensor -> tensor<4xindex> // CHECK: [[VAR_8_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_0_]], [[VAR_7_]], dims = [] : (tensor, tensor<4xindex>) -> tensor -// CHECK: [[VAR_9_:%.+]] = stablehlo.compare GT, [[VAR_4_]], [[VAR_8_]], NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_9_:%.+]] = stablehlo.compare GT, [[VAR_4_]], [[VAR_8_]] : (tensor, tensor) -> tensor // CHECK: [[VAR_10_:%.+]] = stablehlo.select [[VAR_9_]], [[VAR_4_]], [[VAR_6_]] : tensor, tensor // CHECK: return [[VAR_10_]] : tensor // CHECK: } diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir index 2a8894d946..d74f7288f6 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir @@ -32,45 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor) -> () } -// CHECK-LABEL: func.func @test_softmax_dynamic -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index -// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index -// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index -// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> -// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor -// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor -// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> 
index -// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index -// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index -// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> -// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor -// CHECK: return [[VAR_28_]] : tensor -// CHECK: } +// TODO: Re-enable dynamic shape test +// func.func @test_softmax_dynamic +// ([[PARAM_0_:%.+]]: tensor) -> tensor { +// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// [[CST_2_:%.+]] = arith.constant 2 : index +// [[CST_1_:%.+]] = arith.constant 1 : index +// [[CST_0_:%.+]] = arith.constant 0 : index +// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor +// separator of consecutive DAGs +// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor +// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// separator of consecutive DAGs +// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index +// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index +// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index +// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> +// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> +// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor +// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor +// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor +// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// separator of consecutive DAGs +// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index +// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index +// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, 
index, index +// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> +// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> +// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor +// return [[VAR_28_]] : tensor +// } // ----- diff --git a/test/mlir/conversion/onnx_to_stablehlo/RNN/LSTM-loop.mlir b/test/mlir/conversion/onnx_to_stablehlo/RNN/LSTM-loop.mlir index 80a12021c4..d1c9922a5d 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/RNN/LSTM-loop.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/RNN/LSTM-loop.mlir @@ -76,7 +76,7 @@ func.func @test_lstm_loop(%arg0 : tensor<128x16x512xf32>, %arg1 : tensor<2x2048x // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.slice [[VAR_32_]] [1792:2048] : (tensor<2048xf32>) -> tensor<256xf32> // CHECK-DAG: [[VAR_49_:%.+]]:4 = stablehlo.while([[VAR_iterArg_:%.+]] = [[VAR_c_7_]], [[VAR_iterArg_9_:%.+]] = [[VAR_cst_5_]], [[VAR_iterArg_10_:%.+]] = [[VAR_10_]], [[VAR_iterArg_11_:%.+]] = [[VAR_12_]]) : tensor<1xi64>, tensor<128x1x16x256xf32>, tensor<16x256xf32>, tensor<16x256xf32> // CHECK: cond { -// CHECK: [[VAR_52_:%.+]] = stablehlo.compare LT, [[VAR_iterArg_]], [[VAR_c_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_52_:%.+]] = stablehlo.compare LT, [[VAR_iterArg_]], [[VAR_c_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_53_:%.+]] = stablehlo.reshape [[VAR_52_]] : (tensor<1xi1>) -> tensor // CHECK: stablehlo.return [[VAR_53_]] : tensor // CHECK: } do { @@ -156,7 +156,7 @@ func.func @test_lstm_loop(%arg0 : tensor<128x16x512xf32>, %arg1 : tensor<2x2048x // CHECK: } // CHECK: [[VAR_50_:%.+]]:4 = stablehlo.while([[VAR_iterArg_1_:%.+]] = [[VAR_c_]], [[VAR_iterArg_9_1_:%.+]] = [[VAR_c_]]st_5, [[VAR_iterArg_10_1_:%.+]] = [[VAR_14_]], [[VAR_iterArg_11_1_:%.+]] = [[VAR_16_]]) : tensor<1xi64>, tensor<128x1x16x256xf32>, tensor<16x256xf32>, tensor<16x256xf32> // CHECK: cond { -// CHECK: [[VAR_52_2_:%.+]] = stablehlo.compare GE, [[VAR_iterArg_1_]], [[VAR_c_7_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_52_2_:%.+]] = stablehlo.compare GE, [[VAR_iterArg_1_]], [[VAR_c_7_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_53_2_:%.+]] = stablehlo.reshape [[VAR_52_2_]] : (tensor<1xi1>) -> tensor // CHECK: stablehlo.return [[VAR_53_2_]] : tensor // CHECK: } do { diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/ArgMax.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/ArgMax.mlir index 0137e17eb5..d619c36369 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/ArgMax.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/ArgMax.mlir @@ -13,9 +13,9 @@ func.func @test_argmax_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xi64> // CHECK: [[VAR_3_:%.+]] = stablehlo.dynamic_iota [[VAR_0_]], dim = 3 : (tensor<4xindex>) -> tensor<5x5x1x32xi64> // CHECK: [[VAR_4_:%.+]]:2 = stablehlo.reduce(%arg0 init: [[VAR_2_]]), (%1 
init: [[VAR_1_]]) across dimensions = [3] : (tensor<5x5x1x32xf32>, tensor<5x5x1x32xi64>, tensor, tensor) -> (tensor<5x5x1xf32>, tensor<5x5x1xi64>) // CHECK: reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, %arg4: tensor) { -// CHECK: [[VAR_6_:%.+]] = stablehlo.compare GE, %arg1, %arg3, NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_6_:%.+]] = stablehlo.compare GE, %arg1, %arg3 : (tensor, tensor) -> tensor // CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.select [[VAR_6_]], %arg1, %arg3 : tensor, tensor -// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.compare EQ, %arg1, %arg3, NOTYPE : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.compare EQ, %arg1, %arg3 : (tensor, tensor) -> tensor // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.minimum %arg2, %arg4 : tensor // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.select [[VAR_6_]], %arg2, %arg4 : tensor, tensor // CHECK: [[VAR_11_:%.+]] = stablehlo.select [[VAR_8_]], [[VAR_9_]], [[VAR_10_]] : tensor, tensor @@ -43,9 +43,9 @@ func.func @test_argmax_verifier_2(%arg0 : tensor<5x?x1x32xf32>) -> tensor<*xi64> // CHECK: [[VAR_3_:%.+]] = stablehlo.dynamic_iota [[VAR_2_]], dim = 3 : (tensor<4xindex>) -> tensor<5x?x1x32xi64> // CHECK: [[VAR_4_:%.+]]:2 = stablehlo.reduce([[PARAM_0_]] init: [[VAR_0_]]), ([[VAR_3_]] init: [[VAR_1_]]) across dimensions = [3] : (tensor<5x?x1x32xf32>, tensor<5x?x1x32xi64>, tensor, tensor) -> (tensor<5x?x1xf32>, tensor<5x?x1xi64>) // CHECK: reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, %arg4: tensor) { -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare GE, %arg1, %arg3, NOTYPE : (tensor, tensor) -> tensor +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare GE, %arg1, %arg3 : (tensor, tensor) -> tensor // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.select [[VAR_11_]], %arg1, %arg3 : tensor, tensor -// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.compare EQ, %arg1, %arg3, NOTYPE : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.compare EQ, %arg1, %arg3 : (tensor, tensor) -> tensor // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.minimum %arg2, %arg4 : tensor // CHECK-DAG: [[VAR_15_:%.+]] = stablehlo.select [[VAR_11_]], %arg2, %arg4 : tensor, tensor // CHECK: [[VAR_16_:%.+]] = stablehlo.select [[VAR_13_]], [[VAR_14_]], [[VAR_15_]] : tensor, tensor diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir index 217081c416..ae427d590f 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir @@ -12,7 +12,7 @@ func.func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<3> : tensor<2x2xi64> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64> // CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64> // CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) <{batch_dims = 0 : i64, dim = 0 : i64}> : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32> @@ -39,7 +39,7 @@ func.func @test_gather_dynamic_axis0(%arg0 
: tensor) -> tensor<2x2x?xf3 // CHECK-DAG: [[DIM_TENSOR_:%.+]] = tensor.from_elements [[DIM_CAST_]] : tensor // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[DIM_TENSOR_]], [[INDICES_SHAPE_]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x2xi64> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64> // CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64> // CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) <{batch_dims = 0 : i64, dim = 0 : i64}> : (tensor, tensor<2x2xi64>) -> tensor<2x2x?xf32> @@ -60,7 +60,7 @@ func.func @test_gather_axis0neg(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<3> : tensor<2x2xi64> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64> // CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64> // CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) <{batch_dims = 0 : i64, dim = 0 : i64}> : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32> @@ -81,7 +81,7 @@ func.func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> { // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<1x2xi64> // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<3> : tensor<1x2xi64> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<1x2xi64>, tensor<1x2xi64>) -> tensor<1x2xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]] : (tensor<1x2xi64>, tensor<1x2xi64>) -> tensor<1x2xi1> // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<1x2xi64> // CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<1x2xi1>, tensor<1x2xi64> // CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) <{batch_dims = 0 : i64, dim = 1 : i64}> : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir index 1c602d3871..a5893a6833 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir @@ -8,7 +8,7 @@ func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>) // CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.broadcast_in_dim [[VAR_0_]], dims = [] : (tensor) -> tensor<2x2xi64> -// CHECK-DAG: 
[[VAR_3_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_2_]] : tensor<2x2xi64> // CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64> // CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_reshape [[VAR_5_]], [[CST_]] : (tensor<2x2xi64>, tensor<3xindex>) -> tensor<2x2x1xi64> diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/OneHot.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/OneHot.mlir index b818e5e70a..dcdab58bf0 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/OneHot.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/OneHot.mlir @@ -16,10 +16,10 @@ func.func @test_onehot(%arg0 : tensor<2x3x4xi64>) -> tensor<*xi64> { // CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.broadcast_in_dim [[VAR_0_]], dims = [] : (tensor) -> tensor<2x3x4x64xi64> // CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.broadcast_in_dim [[VAR_1_]], dims = [0] : (tensor<1xi64>) -> tensor<2x3x4x64xi64> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.compare GE, [[VAR_4_]], [[VAR_5_]], NOTYPE : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> +// CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.compare GE, [[VAR_4_]], [[VAR_5_]] : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.add [[VAR_4_]], [[VAR_6_]] : tensor<2x3x4x64xi64> // CHECK: [[VAR_9_:%.+]] = stablehlo.select [[VAR_7_]], [[VAR_4_]], [[VAR_8_]] : tensor<2x3x4x64xi1>, tensor<2x3x4x64xi64> -// CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.compare EQ, [[VAR_9_]], [[VAR_3_]], NOTYPE : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> +// CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.compare EQ, [[VAR_9_]], [[VAR_3_]] : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> // CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.slice [[VAR_2_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.slice [[VAR_2_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Slice.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Slice.mlir index 99a92e2ab6..32b61e9749 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Slice.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Slice.mlir @@ -23,7 +23,7 @@ func.func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add 
[[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -34,20 +34,20 @@ func.func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -58,13 +58,13 @@ func.func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] 
: tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -99,7 +99,7 @@ func.func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor< // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add [[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -110,20 +110,20 @@ func.func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor< // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, 
tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -134,13 +134,13 @@ func.func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor< // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -175,7 +175,7 @@ func.func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = 
stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add [[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -186,20 +186,20 @@ func.func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -210,13 +210,13 @@ func.func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : 
tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -251,7 +251,7 @@ func.func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add [[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -262,20 +262,20 @@ func.func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// 
CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -286,13 +286,13 @@ func.func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<* // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], 
[[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -327,7 +327,7 @@ func.func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add [[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -338,20 +338,20 @@ func.func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] 
= stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -362,13 +362,13 @@ func.func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -403,7 +403,7 @@ func.func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[VAR_5_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_10_:%.+]] = stablehlo.slice [[VAR_4_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_11_:%.+]] = stablehlo.compare LT, [[VAR_9_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.negate [[VAR_9_]] : tensor<1xi64> // CHECK-DAG: [[VAR_14_:%.+]] = stablehlo.add [[VAR_10_]], [[VAR_7_]] : tensor<1xi64> @@ -414,20 +414,20 @@ func.func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_10_]] : tensor<1xi1>, tensor<1xi64> // 
CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_13_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_20_:%.+]] = stablehlo.select [[VAR_12_]], [[VAR_16_]], [[PARAM_0_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_21_:%.+]] = stablehlo.compare GT, [[VAR_18_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_22_:%.+]] = stablehlo.select [[VAR_21_]], [[VAR_2_]], [[VAR_18_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.compare LT, [[VAR_22_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.add [[VAR_22_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.select [[VAR_23_]], [[VAR_24_]], [[VAR_22_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_17_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_17_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[VAR_5_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.slice [[VAR_4_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_32_:%.+]] = stablehlo.compare LT, [[VAR_30_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_32_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<2xindex>) -> tensor<2x4xi1> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.negate [[VAR_30_]] : tensor<1xi64> // CHECK-DAG: [[VAR_35_:%.+]] = stablehlo.add [[VAR_31_]], [[VAR_7_]] : tensor<1xi64> @@ -438,13 +438,13 @@ func.func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> te // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_31_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_34_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_41_:%.+]] = stablehlo.select [[VAR_33_]], [[VAR_37_]], [[VAR_20_]] : tensor<2x4xi1>, tensor<2x4xf32> -// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_42_:%.+]] = stablehlo.compare GT, [[VAR_39_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_43_:%.+]] = stablehlo.select [[VAR_42_]], [[VAR_1_]], [[VAR_39_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.compare LT, [[VAR_43_]], [[VAR_6_]] : 
(tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.add [[VAR_43_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.select [[VAR_44_]], [[VAR_45_]], [[VAR_43_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_38_]], [[VAR_6_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_38_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.concatenate [[VAR_28_]], [[VAR_49_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> @@ -483,7 +483,7 @@ func.func @dyntest_slice_constant_dynshape_not_spliced(%arg0 : tensor // CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.slice [[VAR_2_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.slice [[VAR_2_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.slice [[VAR_3_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_14_:%.+]] = stablehlo.compare LT, [[VAR_12_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_14_:%.+]] = stablehlo.compare LT, [[VAR_12_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_15_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_14_]], [[VAR_6_]], dims = [0] : (tensor<1xi1>, tensor<3xindex>) -> tensor // CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.negate [[VAR_12_]] : tensor<1xi64> // CHECK-DAG: [[VAR_17_:%.+]] = stablehlo.add [[VAR_13_]], [[VAR_5_]] : tensor<1xi64> @@ -494,20 +494,20 @@ func.func @dyntest_slice_constant_dynshape_not_spliced(%arg0 : tensor // CHECK-DAG: [[VAR_21_:%.+]] = stablehlo.select [[VAR_14_]], [[VAR_18_]], [[VAR_13_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.select [[VAR_14_]], [[VAR_16_]], [[VAR_12_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.select [[VAR_15_]], [[VAR_19_]], [[PARAM_0_]] : tensor, tensor -// CHECK: [[VAR_24_:%.+]] = stablehlo.compare GT, [[VAR_21_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_24_:%.+]] = stablehlo.compare GT, [[VAR_21_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_25_:%.+]] = stablehlo.select [[VAR_24_]], [[VAR_1_]], [[VAR_21_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_25_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.compare LT, [[VAR_25_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.add [[VAR_25_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.select [[VAR_26_]], [[VAR_27_]], [[VAR_25_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.compare LT, [[VAR_20_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.compare LT, [[VAR_20_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.add [[VAR_20_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs 
// CHECK-DAG: [[VAR_31_:%.+]] = stablehlo.select [[VAR_29_]], [[VAR_30_]], [[VAR_20_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_32_:%.+]] = stablehlo.slice [[VAR_2_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.slice [[VAR_2_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.slice [[VAR_3_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_35_:%.+]] = stablehlo.compare LT, [[VAR_33_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_35_:%.+]] = stablehlo.compare LT, [[VAR_33_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_36_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_35_]], [[VAR_6_]], dims = [0] : (tensor<1xi1>, tensor<3xindex>) -> tensor // CHECK-DAG: [[VAR_37_:%.+]] = stablehlo.negate [[VAR_33_]] : tensor<1xi64> // CHECK-DAG: [[VAR_38_:%.+]] = stablehlo.add [[VAR_34_]], [[VAR_5_]] : tensor<1xi64> @@ -518,13 +518,13 @@ func.func @dyntest_slice_constant_dynshape_not_spliced(%arg0 : tensor // CHECK-DAG: [[VAR_42_:%.+]] = stablehlo.select [[VAR_35_]], [[VAR_39_]], [[VAR_34_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_43_:%.+]] = stablehlo.select [[VAR_35_]], [[VAR_37_]], [[VAR_33_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.select [[VAR_36_]], [[VAR_40_]], [[VAR_23_]] : tensor, tensor -// CHECK: [[VAR_45_:%.+]] = stablehlo.compare GT, [[VAR_42_]], [[VAR_0_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_45_:%.+]] = stablehlo.compare GT, [[VAR_42_]], [[VAR_0_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_46_:%.+]] = stablehlo.select [[VAR_45_]], [[VAR_0_]], [[VAR_42_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_46_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.compare LT, [[VAR_46_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_48_:%.+]] = stablehlo.add [[VAR_46_]], [[VAR_0_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_49_:%.+]] = stablehlo.select [[VAR_47_]], [[VAR_48_]], [[VAR_46_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.compare LT, [[VAR_41_]], [[VAR_4_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_50_:%.+]] = stablehlo.compare LT, [[VAR_41_]], [[VAR_4_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_51_:%.+]] = stablehlo.add [[VAR_41_]], [[VAR_0_]] : tensor<1xi64> // CHECK: [[VAR_52_:%.+]] = stablehlo.select [[VAR_50_]], [[VAR_51_]], [[VAR_41_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_53_:%.+]] = stablehlo.concatenate [[VAR_4_]], [[VAR_31_]], [[VAR_52_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> @@ -555,7 +555,7 @@ func.func @compute_slice_all_dyn(%arg0 : tensor<2xi64>, %arg1 : tensor<2xi64>, % // CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.slice [[PARAM_0_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.slice [[PARAM_2_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_9_:%.+]] = stablehlo.slice [[PARAM_1_]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_10_:%.+]] = stablehlo.compare LT, [[VAR_8_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_10_:%.+]] = stablehlo.compare LT, [[VAR_8_]], [[VAR_5_]] : (tensor<1xi64>, 
tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_10_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<3xindex>) -> tensor<3x4x5xi1> // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.negate [[VAR_8_]] : tensor<1xi64> // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.add [[VAR_9_]], [[VAR_6_]] : tensor<1xi64> @@ -566,20 +566,20 @@ func.func @compute_slice_all_dyn(%arg0 : tensor<2xi64>, %arg1 : tensor<2xi64>, % // CHECK-DAG: [[VAR_17_:%.+]] = stablehlo.select [[VAR_10_]], [[VAR_14_]], [[VAR_9_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_18_:%.+]] = stablehlo.select [[VAR_10_]], [[VAR_12_]], [[VAR_8_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_19_:%.+]] = stablehlo.select [[VAR_11_]], [[VAR_15_]], [[VAR_4_]] : tensor<3x4x5xi1>, tensor<3x4x5xi64> -// CHECK: [[VAR_20_:%.+]] = stablehlo.compare GT, [[VAR_17_]], [[VAR_2_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_20_:%.+]] = stablehlo.compare GT, [[VAR_17_]], [[VAR_2_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_21_:%.+]] = stablehlo.select [[VAR_20_]], [[VAR_2_]], [[VAR_17_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.compare LT, [[VAR_21_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.compare LT, [[VAR_21_]], [[VAR_5_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_23_:%.+]] = stablehlo.add [[VAR_21_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_24_:%.+]] = stablehlo.select [[VAR_22_]], [[VAR_23_]], [[VAR_21_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.compare LT, [[VAR_16_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_25_:%.+]] = stablehlo.compare LT, [[VAR_16_]], [[VAR_5_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.add [[VAR_16_]], [[VAR_2_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_28_:%.+]] = stablehlo.slice [[PARAM_0_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_29_:%.+]] = stablehlo.slice [[PARAM_2_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> // CHECK-DAG: [[VAR_30_:%.+]] = stablehlo.slice [[PARAM_1_]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> -// CHECK: [[VAR_31_:%.+]] = stablehlo.compare LT, [[VAR_29_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_31_:%.+]] = stablehlo.compare LT, [[VAR_29_]], [[VAR_5_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_32_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_31_]], [[VAR_0_]], dims = [0] : (tensor<1xi1>, tensor<3xindex>) -> tensor<3x4x5xi1> // CHECK-DAG: [[VAR_33_:%.+]] = stablehlo.negate [[VAR_29_]] : tensor<1xi64> // CHECK-DAG: [[VAR_34_:%.+]] = stablehlo.add [[VAR_30_]], [[VAR_6_]] : tensor<1xi64> @@ -590,13 +590,13 @@ func.func @compute_slice_all_dyn(%arg0 : tensor<2xi64>, %arg1 : tensor<2xi64>, % // CHECK-DAG: [[VAR_38_:%.+]] = stablehlo.select [[VAR_31_]], [[VAR_35_]], [[VAR_30_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_39_:%.+]] = stablehlo.select [[VAR_31_]], [[VAR_33_]], [[VAR_29_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_40_:%.+]] = stablehlo.select [[VAR_32_]], [[VAR_36_]], [[VAR_19_]] : tensor<3x4x5xi1>, tensor<3x4x5xi64> -// CHECK: 
[[VAR_41_:%.+]] = stablehlo.compare GT, [[VAR_38_]], [[VAR_1_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK: [[VAR_41_:%.+]] = stablehlo.compare GT, [[VAR_38_]], [[VAR_1_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK: [[VAR_42_:%.+]] = stablehlo.select [[VAR_41_]], [[VAR_1_]], [[VAR_38_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_43_:%.+]] = stablehlo.compare LT, [[VAR_42_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_43_:%.+]] = stablehlo.compare LT, [[VAR_42_]], [[VAR_5_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_44_:%.+]] = stablehlo.add [[VAR_42_]], [[VAR_1_]] : tensor<1xi64> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_45_:%.+]] = stablehlo.select [[VAR_43_]], [[VAR_44_]], [[VAR_42_]] : tensor<1xi1>, tensor<1xi64> -// CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.compare LT, [[VAR_37_]], [[VAR_5_]], NOTYPE : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> +// CHECK-DAG: [[VAR_46_:%.+]] = stablehlo.compare LT, [[VAR_37_]], [[VAR_5_]] : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> // CHECK-DAG: [[VAR_47_:%.+]] = stablehlo.add [[VAR_37_]], [[VAR_1_]] : tensor<1xi64> // CHECK: [[VAR_48_:%.+]] = stablehlo.select [[VAR_46_]], [[VAR_47_]], [[VAR_37_]] : tensor<1xi1>, tensor<1xi64> // CHECK-DAG: [[VAR_49_:%.+]] = stablehlo.concatenate [[VAR_5_]], [[VAR_27_]], [[VAR_48_]], dim = 0 : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> diff --git a/test/mlir/driver/compile_phases.mlir b/test/mlir/driver/compile_phases.mlir index a79bdc7d51..3e94ccbfb0 100644 --- a/test/mlir/driver/compile_phases.mlir +++ b/test/mlir/driver/compile_phases.mlir @@ -1,10 +1,11 @@ // RUN: onnx-mlir %s -o %t| FileCheck %s && rm %t.so -// CHECK: [1/5] {{.*}} Importing ONNX Model to MLIR Module -// CHECK: [2/5] {{.*}} Compiling and Optimizing MLIR Module -// CHECK: [3/5] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode -// CHECK: [4/5] {{.*}} Generating Object from LLVM Bitcode -// CHECK: [5/5] {{.*}} Linking and Generating the Output Shared Library +// CHECK: [1/6] {{.*}} Importing ONNX Model to MLIR Module from +// CHECK: [2/6] {{.*}} Compiling and Optimizing MLIR Module +// CHECK: [3/6] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode +// CHECK: [4/6] {{.*}} Generating Object from LLVM Bitcode +// CHECK: [5/6] {{.*}} Linking and Generating the Output Shared Library +// CHECK: [6/6] {{.*}} Compilation completed module { func.func @main_graph(%arg0: tensor) -> tensor { onnx.Return %arg0 : tensor diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index cfd7244565..c19b1974f7 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -318,6 +318,120 @@ func.func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> } +// ----- + +// CHECK-LABEL: @test_ceil() -> tensor<3x2xbf16> +func.func @test_ceil() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[[12.2, -12.2], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Ceil"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[1.300000e+01, -1.200000e+01], [0.000000e+00, 0x7FC0], [0x7F80, 0xFF80]]> + // CHECK-NOT: "onnx.Ceil" +} + +// ----- + +// 
CHECK-LABEL: @test_cos() -> tensor<3x2xf32> +func.func @test_cos() -> tensor<3x2xf32> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + // Results: Positive, Positive, One, NaN, +/-NaN, +/-NaN + // Note: Implementations of cos can output either NaN or -NaN for +/-Inf numbers. + %0 = onnx.Constant dense<[[0.625, -0.625], [0.0, 0x7FC00000], [0x7F800000, 0xFF800000]]> : tensor<3x2xf32> + %1 = "onnx.Cos"(%0) : (tensor<3x2xf32>) -> tensor<3x2xf32> + "onnx.Return"(%1) : (tensor<3x2xf32>) -> () + // CHECK: onnx.Constant dense<{{.}}[0.810963094, 0.810963094], [1.000000e+00, 0x7FC00000], [0x{{F|7}}FC00000, 0x{{F|7}}FC00000]]> + // CHECK-NOT: "onnx.Cos" +} + +// ----- + +// CHECK-LABEL: @test_erf() -> tensor<3x2xbf16> +func.func @test_erf() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[[0.0625, -1.0], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Erf"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[7.031250e-02, -8.437500e-01], [0.000000e+00, 0x7FC0], [1.000000e+00, -1.000000e+00]]> + // CHECK-NOT: "onnx.Erf" +} + +// ----- + +// CHECK-LABEL: @test_exp() -> tensor<3x2xbf16> +func.func @test_exp() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[[0.0625, -1.0], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Exp"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[1.062500e+00, 3.671880e-01], [1.000000e+00, 0x7FC0], [0x7F80, 0.000000e+00]]> + // CHECK-NOT: "onnx.Exp" +} + +// ----- + +// CHECK-LABEL: @test_floor() -> tensor<3x2xbf16> +func.func @test_floor() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[[12.2, -12.2], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Floor"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[1.200000e+01, -1.300000e+01], [0.000000e+00, 0x7FC0], [0x7F80, 0xFF80]]> + // CHECK-NOT: "onnx.Floor" +} + +// ----- + +// CHECK-LABEL: @test_log() -> tensor<3x2xbf16> +func.func @test_log() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[[0.0625, -1.0], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Log"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // Note: Implementations of log can output either NaN or -NaN for negative and -Inf numbers. 
+ // CHECK: onnx.Constant dense<{{.}}[-2.765630e+00, 0x{{F|7}}FC0], [0xFF80, 0x7FC0], [0x7F80, 0x{{F|7}}FC0]]> + // CHECK-NOT: "onnx.Log" +} + +// ----- + +// CHECK-LABEL: @test_not() -> tensor<2xi1> +func.func @test_not() -> tensor<2xi1> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + %0 = onnx.Constant dense<[false, true]> : tensor<2xi1> + %1 = "onnx.Not"(%0) : (tensor<2xi1>) -> tensor<2xi1> + "onnx.Return"(%1) : (tensor<2xi1>) -> () + // CHECK: onnx.Constant dense<[true, false]> + // CHECK-NOT: "onnx.Not" +} + +// ----- + +// CHECK-LABEL: @test_reciprocal() -> tensor<3x2xbf16> +func.func @test_reciprocal() -> tensor<3x2xbf16> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + // Results: Positive, Negative, +Inf, NaN, +Zero, -Zero + %0 = onnx.Constant dense<[[0.25, -0.25], [0.0, 0x7FC0], [0x7F80, 0xFF80]]> : tensor<3x2xbf16> + %1 = "onnx.Reciprocal"(%0) : (tensor<3x2xbf16>) -> tensor<3x2xbf16> + "onnx.Return"(%1) : (tensor<3x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[4.000000e+00, -4.000000e+00], [0x7F80, 0x7FC0], [0.000000e+00, -0.000000e+00]]> + // CHECK-NOT: "onnx.Reciprocal" +} + +// ----- + +// CHECK-LABEL: @test_sin() -> tensor<3x2xf32> +func.func @test_sin() -> tensor<3x2xf32> { + // Test Positive, Negative, Zero, NaN, +Inf, -Inf + // Results: Positive, Negative, Zero, NaN, +/-NaN, +/-NaN + // Note: Implementations of sin can output either NaN or -NaN for +/-Inf numbers. + %0 = onnx.Constant dense<[[0.625, -0.625], [0.0, 0x7FC00000], [0x7F800000, 0xFF800000]]> : tensor<3x2xf32> + %1 = "onnx.Sin"(%0) : (tensor<3x2xf32>) -> tensor<3x2xf32> + "onnx.Return"(%1) : (tensor<3x2xf32>) -> () + // CHECK: onnx.Constant dense<{{.}}[0.585097253, -0.585097253], [0.000000e+00, 0x7FC00000], [0x{{F|7}}FC00000, 0x{{F|7}}FC00000]]> + // CHECK-NOT: "onnx.Sin" +} + //===----------------------------------------------------------------------===// /// Transpose tests. @@ -379,6 +493,233 @@ func.func @test_div_ones(%arg0 : tensor<1x2xui8>) -> tensor<1x2xui8> { // CHECK: onnx.Return %arg0 : tensor<1x2xui8> } +// ----- + +// CHECK-LABEL: test_div_by_zero() +func.func @test_div_by_zero() -> tensor<2xui32> { + %0 = onnx.Constant dense<[2, 4]> : tensor<2xui32> + %1 = onnx.Constant dense<[0]> : tensor<1xui32> + %2 = "onnx.Div"(%0, %1) : (tensor<2xui32>, tensor<1xui32>) -> tensor<2xui32> + "onnx.Return"(%2) : (tensor<2xui32>) -> () + // The behavior is undefined, so the values don't matter. Just don't crash. 
+ // CHECK-NOT: {{.*}} = "onnx.Div"{{.*}} +} + +// ----- + +//===----------------------------------------------------------------------===// +/// Clip's test + +func.func @test_clip_max_and_min() -> tensor<3x2xbf16> { + // Test Positive Clamped, Negative Clamped, In range, NaN, -Inf, +Inf + %cst = onnx.Constant dense<[[2.125, -2.125], [0.0, 0x7FC0], [0xFF80, 0x7F80]]> : tensor<3x2xbf16> + %min = onnx.Constant {value = dense<-2.0> : tensor<bf16>} : tensor<bf16> + %max = onnx.Constant {value = dense<2.0> : tensor<bf16>} : tensor<bf16> + %0 = "onnx.Clip"(%cst, %min, %max) : (tensor<3x2xbf16>, tensor<bf16>, tensor<bf16>) -> tensor<3x2xbf16> + return %0 : tensor<3x2xbf16> +// CHECK-LABEL: func @test_clip_max_and_min + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[2.000000e+00, -2.000000e+00], [0.000000e+00, 0x7FC0], [-2.000000e+00, 2.000000e+00]]> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + +// CHECK-LABEL: func @test_clip_no_min +func.func @test_clip_no_min() -> tensor<3x2xbf16> { + // Test Positive Clamped, Negative Clamped, In range, NaN, -Inf, +Inf + %cst = onnx.Constant dense<[[2.1, -2.125], [0.0, 0x7FC0], [0xFF80, 0x7F80]]> : tensor<3x2xbf16> + %none = "onnx.NoValue"() {value} : () -> none + %max = onnx.Constant {value = dense<2.0> : tensor<bf16>} : tensor<bf16> + %0 = "onnx.Clip"(%cst, %none, %max) : (tensor<3x2xbf16>, none, tensor<bf16>) -> tensor<3x2xbf16> + return %0 : tensor<3x2xbf16> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[2.000000e+00, -2.125000e+00], [0.000000e+00, 0x7FC0], [-3.389530e+38, 2.000000e+00]]> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + +// CHECK-LABEL: func @test_clip_no_max +func.func @test_clip_no_max() -> tensor<3x2xbf16> { + // Test Positive Clamped, Negative Clamped, In range, NaN, -Inf, +Inf + %cst = onnx.Constant dense<[[2.125, -2.125], [0.0, 0x7FC0], [0xFF80, 0x7F80]]> : tensor<3x2xbf16> + %min = onnx.Constant {value = dense<-2.0> : tensor<bf16>} : tensor<bf16> + %none = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Clip"(%cst, %min, %none) : (tensor<3x2xbf16>, tensor<bf16>, none) -> tensor<3x2xbf16> + return %0 : tensor<3x2xbf16> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[2.125000e+00, -2.000000e+00], [0.000000e+00, 0x7FC0], [-2.000000e+00, 3.389530e+38]]> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + +// CHECK-LABEL: func @test_clip_no_min_no_max +func.func @test_clip_no_min_no_max() -> tensor<3x2xbf16> { + // Test Positive Clamped, Negative Clamped, In range, NaN, -Inf, +Inf + %cst = onnx.Constant dense<[[2.125, -2.125], [0.0, 0x7FC0], [0xFF80, 0x7F80]]> : tensor<3x2xbf16> + %none = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Clip"(%cst, %none, %none) : (tensor<3x2xbf16>, none, none) -> tensor<3x2xbf16> + return %0 : tensor<3x2xbf16> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[2.125000e+00, -2.125000e+00], [0.000000e+00, 0x7FC0], [-3.389530e+38, 3.389530e+38]]> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + +// ONNX Specification does not define what happens when min > max. 
+// Use numpy's specification as discussed in https://github.com/onnx/onnx/issues/6165 + +// CHECK-LABEL: func @test_clip_min_greater_than_max1 +func.func @test_clip_min_greater_than_max1() -> tensor<3x2xbf16> { + %cst = onnx.Constant dense<[[2.125, -2.125], [0.0, 0x7FC0], [0xFF80, 0x7F80]]> : tensor<3x2xbf16> + %min = onnx.Constant {value = dense<8.0> : tensor<bf16>} : tensor<bf16> + %max = onnx.Constant {value = dense<1.0> : tensor<bf16>} : tensor<bf16> + %0 = "onnx.Clip"(%cst, %min, %max) : (tensor<3x2xbf16>, tensor<bf16>, tensor<bf16>) -> tensor<3x2xbf16> + return %0 : tensor<3x2xbf16> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[1.000000e+00, 1.000000e+00], [1.000000e+00, 0x7FC0], [1.000000e+00, 1.000000e+00]]> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + +// CHECK-LABEL: func @test_clip_min_greater_than_max2 +func.func @test_clip_min_greater_than_max2() -> tensor<3x2xi32> { + %cst = onnx.Constant dense<[[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32> + %min = onnx.Constant {value = dense<8> : tensor<i32>} : tensor<i32> + %max = onnx.Constant {value = dense<1> : tensor<i32>} : tensor<i32> + %0 = "onnx.Clip"(%cst, %min, %max) : (tensor<3x2xi32>, tensor<i32>, tensor<i32>) -> tensor<3x2xi32> + return %0 : tensor<3x2xi32> + // CHECK: {{.*}} = onnx.Constant dense<1> + // CHECK-NOT: {{.*}} = "onnx.Clip" +} + + +// ----- + +//===----------------------------------------------------------------------===// +/// Bitwise's test + +// CHECK-LABEL: @test_bitwise_not() -> tensor<4xi32> +func.func @test_bitwise_not() -> tensor<4xi32> { + %0 = onnx.Constant dense<[0, 0xFFFFFFFF, 0xFFFFFFFE, 0x000000FF]> : tensor<4xi32> + %1 = "onnx.BitwiseNot"(%0) : (tensor<4xi32>) -> tensor<4xi32> + "onnx.Return"(%1) : (tensor<4xi32>) -> () + // CHECK: onnx.Constant dense<[-1, 0, 1, -256]> + // CHECK-NOT: "onnx.BitwiseNot" +} + +// ----- + +// CHECK-LABEL: @test_bitwise_and() -> tensor<3xi32> +func.func @test_bitwise_and() -> tensor<3xi32> { + %0 = onnx.Constant dense<[0xFFFFFFFE, 0xFFFF0000, 256]> : tensor<3xi32> + %1 = onnx.Constant dense<0xFFFFFFFF> : tensor<i32> + %2 = "onnx.BitwiseAnd"(%0, %1) : (tensor<3xi32>, tensor<i32>) -> tensor<3xi32> + "onnx.Return"(%2) : (tensor<3xi32>) -> () + // CHECK: onnx.Constant dense<[-2, -65536, 256]> + // CHECK-NOT: "onnx.BitwiseAnd"{{.*}} +} + +// ----- + +// CHECK-LABEL: @test_bitwise_and() -> tensor<3xi32> +func.func @test_bitwise_and() -> tensor<3xi32> { + %0 = onnx.Constant dense<[-2, 15, 0x0000FFFF]> : tensor<3xi32> + %1 = onnx.Constant dense<[0xFFFFFFFF]> : tensor<1xi32> + %2 = "onnx.BitwiseAnd"(%0, %1) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + "onnx.Return"(%2) : (tensor<3xi32>) -> () + // CHECK: onnx.Constant dense<[-2, 15, 65535]> + // CHECK-NOT: "onnx.BitwiseAnd"{{.*}} +} + +// ----- + +// CHECK-LABEL: @test_bitwise_and() -> tensor<3xi32> +func.func @test_bitwise_and() -> tensor<3xi32> { + %0 = onnx.Constant dense<[0xFFFFFFFE, 15, 255]> : tensor<3xi32> + %1 = onnx.Constant dense<[0xFFFFFFFF, 0xFFFFFFF0, 0x0000000F]> : tensor<3xi32> + %2 = "onnx.BitwiseAnd"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + "onnx.Return"(%2) : (tensor<3xi32>) -> () + // CHECK: onnx.Constant dense<[-2, 0, 15]> + // CHECK-NOT: "onnx.BitwiseAnd"{{.*}} +} + +// ----- + +// CHECK-LABEL: @test_bitwise_or() -> tensor<3xi32> +func.func @test_bitwise_or() -> tensor<3xi32> { + %0 = onnx.Constant dense<[0xFFFFFFFE, 0xFFFFFFF0, 0xFFFFFF00]> : tensor<3xi32> + %1 = onnx.Constant dense<[0xFFFFFFF1, 1, 2]> : tensor<3xi32> + %2 = "onnx.BitwiseOr"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + "onnx.Return"(%2) : (tensor<3xi32>) -> () + // CHECK: 
+  // CHECK-NOT: "onnx.BitwiseOr"{{.*}}
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+/// And/Or/Xor test
+
+// CHECK-LABEL: @test_and() -> tensor<3xi1>
+func.func @test_and() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<true> : tensor<i1>
+  %2 = "onnx.And"(%0, %1) : (tensor<3xi1>, tensor<i1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<[true, true, false]>
+  // CHECK-NOT: "onnx.And"{{.*}}
+}
+
+// CHECK-LABEL: @test_and2() -> tensor<3xi1>
+func.func @test_and2() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<[false, true, false]> : tensor<3xi1>
+  %2 = "onnx.And"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<[false, true, false]>
+  // CHECK-NOT: "onnx.And"{{.*}}
+}
+
+// CHECK-LABEL: @test_and3() -> tensor<3xi1>
+func.func @test_and3() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<false> : tensor<i1>
+  %2 = "onnx.And"(%0, %1) : (tensor<3xi1>, tensor<i1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<false>
+  // CHECK-NOT: "onnx.And"{{.*}}
+}
+
+// CHECK-LABEL: @test_or() -> tensor<3xi1>
+func.func @test_or() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<true> : tensor<i1>
+  %2 = "onnx.Or"(%0, %1) : (tensor<3xi1>, tensor<i1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<true>
+  // CHECK-NOT: "onnx.Or"{{.*}}
+}
+
+// CHECK-LABEL: @test_or2() -> tensor<3xi1>
+func.func @test_or2() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<[false, true, false]> : tensor<3xi1>
+  %2 = "onnx.Or"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<[true, true, false]>
+  // CHECK-NOT: "onnx.Or"{{.*}}
+}
+
+// CHECK-LABEL: @test_xor() -> tensor<3xi1>
+func.func @test_xor() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<true> : tensor<i1>
+  %2 = "onnx.Xor"(%0, %1) : (tensor<3xi1>, tensor<i1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<[false, false, true]>
+  // CHECK-NOT: "onnx.Xor"{{.*}}
+}
+
+// CHECK-LABEL: @test_xor2() -> tensor<3xi1>
+func.func @test_xor2() -> tensor<3xi1> {
+  %0 = onnx.Constant dense<[true, true, false]> : tensor<3xi1>
+  %1 = onnx.Constant dense<[false, true, false]> : tensor<3xi1>
+  %2 = "onnx.Xor"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1>
+  "onnx.Return"(%2) : (tensor<3xi1>) -> ()
+  // CHECK: onnx.Constant dense<[true, false, false]>
+  // CHECK-NOT: "onnx.Xor"{{.*}}
+}
+
 //===----------------------------------------------------------------------===//
 /// Equal test
diff --git a/test/perf/PerfHelper.hpp b/test/perf/PerfHelper.hpp
index a14d856434..f11970f64a 100644
--- a/test/perf/PerfHelper.hpp
+++ b/test/perf/PerfHelper.hpp
@@ -23,4 +23,4 @@ void perf_recordFlops(benchmark::State &state, float f);
 int perf_main(int argc, char **argv);
 
 #define PERF_MAIN()                                                            \
-  int main(int argc, char **argv) { return perf_main(argc, argv); }
+  int main(int argc, char **argv) { return perf_main((argc), (argv)); }
diff --git a/third_party/stablehlo b/third_party/stablehlo
index 1b08c4c0e8..92203a9612 160000
--- a/third_party/stablehlo
+++ b/third_party/stablehlo
@@ -1 +1 @@
-Subproject commit 1b08c4c0e8c893d202bd33c2735562fa0ccc849f
+Subproject commit 92203a9612dbdd6681ccd3e65bc61586c4290df1
diff --git a/utils/build-mlir.cmd b/utils/build-mlir.cmd
index a306527514..df9a9ea258 100644
--- a/utils/build-mlir.cmd
+++ b/utils/build-mlir.cmd
@@ -3,13 +3,14 @@ md llvm-project\build
 cd llvm-project\build
 call cmake %root_dir%\llvm-project\llvm -G "Ninja" ^
   -DCMAKE_INSTALL_PREFIX="%root_dir%\llvm-project\build\install" ^
-  -DLLVM_ENABLE_PROJECTS=mlir ^
+  -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" ^
   -DLLVM_TARGETS_TO_BUILD="host" ^
   -DCMAKE_BUILD_TYPE=Release ^
   -DLLVM_ENABLE_ASSERTIONS=ON ^
   -DLLVM_ENABLE_RTTI=ON ^
   -DLLVM_ENABLE_ZLIB=OFF ^
   -DLLVM_INSTALL_UTILS=ON ^
+  -DENABLE_LIBOMPTARGET=OFF ^
   -DLLVM_ENABLE_LIBEDIT=OFF
 
 call cmake --build . --config Release
diff --git a/utils/build-mlir.sh b/utils/build-mlir.sh
index 28f722c8a9..6a9ff0cddc 100644
--- a/utils/build-mlir.sh
+++ b/utils/build-mlir.sh
@@ -1,11 +1,13 @@
 mkdir llvm-project/build
 cd llvm-project/build
+
 cmake -G Ninja ../llvm \
-  -DLLVM_ENABLE_PROJECTS=mlir \
+  -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
   -DLLVM_TARGETS_TO_BUILD="host" \
   -DCMAKE_BUILD_TYPE=Release \
   -DLLVM_ENABLE_ASSERTIONS=ON \
   -DLLVM_ENABLE_RTTI=ON \
+  -DENABLE_LIBOMPTARGET=OFF \
   -DLLVM_ENABLE_LIBEDIT=OFF
 
 cmake --build . -- ${MAKEFLAGS}
diff --git a/utils/check-onnx-backend-numerical.sh b/utils/check-onnx-backend-numerical.sh
index 20e7091bd7..e7dbee6b75 100644
--- a/utils/check-onnx-backend-numerical.sh
+++ b/utils/check-onnx-backend-numerical.sh
@@ -1,4 +1,4 @@
 # Run backend and numerical tests in parallel
 cd onnx-mlir/build
 CTEST_PARALLEL_LEVEL=$(sysctl -n hw.logicalcpu) \
-cmake --build . --parallel --target check-onnx-backend check-onnx-backend-input-verification check-onnx-numerical
+cmake --build . --parallel --target check-onnx-backend check-onnx-backend-dynamic check-onnx-backend-input-verification check-onnx-numerical
diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh
index bcc1119510..4965b58726 100644
--- a/utils/clone-mlir.sh
+++ b/utils/clone-mlir.sh
@@ -1,3 +1,3 @@
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 6461b921fd06b1c812f1172685b8b7edc0608af7 && cd ..
+cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd ..
diff --git a/utils/install-onnx-mlir.sh b/utils/install-onnx-mlir.sh
index 001c87d0b6..789d1e4d54 100755
--- a/utils/install-onnx-mlir.sh
+++ b/utils/install-onnx-mlir.sh
@@ -4,11 +4,15 @@ mkdir onnx-mlir/build && cd onnx-mlir/build
 if [[ -z "$pythonLocation" ]]; then
   cmake -G Ninja \
         -DCMAKE_CXX_COMPILER=/usr/bin/c++ \
+        -DCMAKE_BUILD_TYPE=Release \
+        -DLLVM_ENABLE_ASSERTIONS=ON \
        -DMLIR_DIR=${MLIR_DIR} \
        ..
 else
   cmake -G Ninja \
        -DCMAKE_CXX_COMPILER=/usr/bin/c++ \
+        -DCMAKE_BUILD_TYPE=Release \
+        -DLLVM_ENABLE_ASSERTIONS=ON \
        -DPython3_ROOT_DIR=$pythonLocation \
        -DMLIR_DIR=${MLIR_DIR} \
        ..