diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 1dc0398494dccf..20f019666a2e6a 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -15,6 +15,7 @@ #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt index 7d40caebe1e053..a0096e5f299d59 100644 --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -4,7 +4,9 @@ add_mlir_library(MLIRDialectUtils StructuredOpsUtils.cpp StaticValueUtils.cpp + DEPENDS + MLIRDialectUtilsIncGen + LINK_LIBS PUBLIC - MLIRArithUtils MLIRIR ) diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 0c8a88da789e26..1e8197e1094424 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/MathExtras.h" @@ -18,13 +17,10 @@ namespace mlir { bool isZeroIndex(OpFoldResult v) { if (!v) return false; - if (auto attr = llvm::dyn_cast_if_present(v)) { - IntegerAttr intAttr = dyn_cast(attr); - return intAttr && intAttr.getValue().isZero(); - } - if (auto cst = v.get().getDefiningOp()) - return cst.value() == 0; - return false; + std::optional constint = getConstantIntValue(v); + if (!constint) + return false; + return *constint == 0; } std::tuple, SmallVector,