diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h index 3105040b876317c..234767deea00afd 100644 --- a/mlir/include/mlir/IR/AttrTypeSubElements.h +++ b/mlir/include/mlir/IR/AttrTypeSubElements.h @@ -16,6 +16,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" +#include "mlir/Support/CyclicReplacerCache.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include @@ -116,9 +117,21 @@ class AttrTypeWalker { /// AttrTypeReplacer //===----------------------------------------------------------------------===// -/// This class provides a utility for replacing attributes/types, and their sub -/// elements. Multiple replacement functions may be registered. -class AttrTypeReplacer { +namespace detail { + +/// This class provides a base utility for replacing attributes/types, and their +/// sub elements. Multiple replacement functions may be registered. +/// +/// This base utility is uncached. Users can choose between two cached versions +/// of this replacer: +/// * For non-cyclic replacer logic, use `AttrTypeReplacer`. +/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`. +/// +/// Concrete implementations implement the following `replace` entry functions: +/// * Attribute replace(Attribute attr); +/// * Type replace(Type type); +template +class AttrTypeReplacerBase { public: //===--------------------------------------------------------------------===// // Application @@ -139,12 +152,6 @@ class AttrTypeReplacer { bool replaceLocs = false, bool replaceTypes = false); - /// Replace the given attribute/type, and recursively replace any sub - /// elements. Returns either the new attribute/type, or nullptr in the case of - /// failure. - Attribute replace(Attribute attr); - Type replace(Type type); - //===--------------------------------------------------------------------===// // Registration //===--------------------------------------------------------------------===// @@ -206,21 +213,114 @@ class AttrTypeReplacer { }); } -private: - /// Internal implementation of the `replace` methods above. - template - T replaceImpl(T element, ReplaceFns &replaceFns); - - /// Replace the sub elements of the given interface. - template - T replaceSubElements(T interface); +protected: + /// Invokes the registered replacement functions from most recently registered + /// to least recently registered until a successful replacement is returned. + /// Unless skipping is requested, invokes `replace` on sub-elements of the + /// current attr/type. + Attribute replaceBase(Attribute attr); + Type replaceBase(Type type); +private: /// The set of replacement functions that map sub elements. std::vector> attrReplacementFns; std::vector> typeReplacementFns; +}; + +} // namespace detail + +/// This is an attribute/type replacer that is naively cached. It is best used +/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any +/// re-occurrence of an in-progress element will be skipped. +class AttrTypeReplacer : public detail::AttrTypeReplacerBase { +public: + Attribute replace(Attribute attr); + Type replace(Type type); + +private: + /// Shared concrete implementation of the public `replace` functions. Invokes + /// `replaceBase` with caching. + template + T cachedReplaceImpl(T element); + + // Stores the opaque pointer of an attribute or type. + DenseMap cache; +}; + +/// This is an attribute/type replacer that supports custom handling of cycles +/// in the replacer logic. In addition to registering replacer functions, it +/// allows registering cycle-breaking functions in the same style. +class CyclicAttrTypeReplacer + : public detail::AttrTypeReplacerBase { +public: + CyclicAttrTypeReplacer(); - /// The set of cached mappings for attributes/types. - DenseMap attrTypeMap; + //===--------------------------------------------------------------------===// + // Application + //===--------------------------------------------------------------------===// + + Attribute replace(Attribute attr); + Type replace(Type type); + + //===--------------------------------------------------------------------===// + // Registration + //===--------------------------------------------------------------------===// + + /// A cycle-breaking function. This is invoked if the same element is asked to + /// be replaced again when the first instance of it is still being replaced. + /// This function must not perform any more recursive `replace` calls. + /// If it is able to break the cycle, it should return a replacement result. + /// Otherwise, it can return std::nullopt to defer cycle breaking to the next + /// repeated element. However, the user must guarantee that, in any possible + /// cycle, there always exists at least one element that can break the cycle. + template + using CycleBreakerFn = std::function(T)>; + + /// Register a cycle-breaking function. + /// When breaking cycles, the mostly recently added cycle-breaking functions + /// will be invoked first. + void addCycleBreaker(CycleBreakerFn fn); + void addCycleBreaker(CycleBreakerFn fn); + + /// Register a cycle-breaking function that doesn't match the default + /// signature. + template >::template arg_t<0>, + typename BaseT = std::conditional_t, + Attribute, Type>> + std::enable_if_t> addCycleBreaker(FnT &&callback) { + addCycleBreaker([callback = std::forward(callback)]( + BaseT base) -> std::optional { + if (auto derived = dyn_cast(base)) + return callback(derived); + return std::nullopt; + }); + } + +private: + /// Invokes the registered cycle-breaker functions from most recently + /// registered to least recently registered until a successful result is + /// returned. + std::optional breakCycleImpl(void *element); + + /// Shared concrete implementation of the public `replace` functions. + template + T cachedReplaceImpl(T element); + + /// The set of registered cycle-breaker functions. + std::vector> attrCycleBreakerFns; + std::vector> typeCycleBreakerFns; + + /// A cache of previously-replaced attr/types. + /// The key of the cache is the opaque value of an AttrOrType. Using + /// AttrOrType allows distinguishing between the two types when invoking + /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue + /// of directly using `AttrOrType` to instantiate the cache. + /// The value of the cache is just the opaque value of the attr/type itself + /// (not the PointerUnion). + using AttrOrType = PointerUnion; + CyclicReplacerCache cache; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp index 79b04966be6eb19..783236ed3a9df67 100644 --- a/mlir/lib/IR/AttrTypeSubElements.cpp +++ b/mlir/lib/IR/AttrTypeSubElements.cpp @@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { } //===----------------------------------------------------------------------===// -/// AttrTypeReplacer +/// AttrTypeReplacerBase //===----------------------------------------------------------------------===// -void AttrTypeReplacer::addReplacement(ReplaceFn fn) { +template +void detail::AttrTypeReplacerBase::addReplacement( + ReplaceFn fn) { attrReplacementFns.emplace_back(std::move(fn)); } -void AttrTypeReplacer::addReplacement(ReplaceFn fn) { + +template +void detail::AttrTypeReplacerBase::addReplacement( + ReplaceFn fn) { typeReplacementFns.push_back(std::move(fn)); } -void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, - bool replaceLocs, bool replaceTypes) { +template +void detail::AttrTypeReplacerBase::replaceElementsIn( + Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { // Functor that replaces the given element if the new value is different, // otherwise returns nullptr. auto replaceIfDifferent = [&](auto element) { - auto replacement = replace(element); + auto replacement = static_cast(this)->replace(element); return (replacement && replacement != element) ? replacement : nullptr; }; @@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, } } -void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op, - bool replaceAttrs, - bool replaceLocs, - bool replaceTypes) { +template +void detail::AttrTypeReplacerBase::recursivelyReplaceElementsIn( + Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { op->walk([&](Operation *nestedOp) { replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes); }); } -template -static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, +template +static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl &newElements, FailureOr &changed) { // Bail early if we failed at any point. @@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, } } -template -T AttrTypeReplacer::replaceSubElements(T interface) { +template +static T replaceSubElements(T interface, Replacer &replacer) { // Walk the current sub-elements, replacing them as necessary. SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; interface.walkImmediateSubElements( [&](Attribute element) { - updateSubElementImpl(element, *this, newAttrs, changed); + updateSubElementImpl(element, replacer, newAttrs, changed); }, [&](Type element) { - updateSubElementImpl(element, *this, newTypes, changed); + updateSubElementImpl(element, replacer, newTypes, changed); }); if (failed(changed)) return nullptr; @@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) { } /// Shared implementation of replacing a given attribute or type element. -template -T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { - const void *opaqueElement = element.getAsOpaquePointer(); - auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement); - if (!inserted) - return T::getFromOpaquePointer(it->second); - +template +static T replaceElementImpl(T element, ReplaceFns &replaceFns, + Replacer &replacer) { T result = element; WalkResult walkResult = WalkResult::advance(); for (auto &replaceFn : llvm::reverse(replaceFns)) { @@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { // If an error occurred, return nullptr to indicate failure. if (walkResult.wasInterrupted() || !result) { - attrTypeMap[opaqueElement] = nullptr; return nullptr; } // Handle replacing sub-elements if this element is also a container. if (!walkResult.wasSkipped()) { // Replace the sub elements of this element, bailing if we fail. - if (!(result = replaceSubElements(result))) { - attrTypeMap[opaqueElement] = nullptr; + if (!(result = replaceSubElements(result, replacer))) { return nullptr; } } - attrTypeMap[opaqueElement] = result.getAsOpaquePointer(); + return result; +} + +template +Attribute detail::AttrTypeReplacerBase::replaceBase(Attribute attr) { + return replaceElementImpl(attr, attrReplacementFns, + *static_cast(this)); +} + +template +Type detail::AttrTypeReplacerBase::replaceBase(Type type) { + return replaceElementImpl(type, typeReplacementFns, + *static_cast(this)); +} + +//===----------------------------------------------------------------------===// +/// AttrTypeReplacer +//===----------------------------------------------------------------------===// + +template class detail::AttrTypeReplacerBase; + +template +T AttrTypeReplacer::cachedReplaceImpl(T element) { + const void *opaqueElement = element.getAsOpaquePointer(); + auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement); + if (!inserted) + return T::getFromOpaquePointer(it->second); + + T result = replaceBase(element); + + cache[opaqueElement] = result.getAsOpaquePointer(); return result; } Attribute AttrTypeReplacer::replace(Attribute attr) { - return replaceImpl(attr, attrReplacementFns); + return cachedReplaceImpl(attr); } -Type AttrTypeReplacer::replace(Type type) { - return replaceImpl(type, typeReplacementFns); +Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); } + +//===----------------------------------------------------------------------===// +/// CyclicAttrTypeReplacer +//===----------------------------------------------------------------------===// + +template class detail::AttrTypeReplacerBase; + +CyclicAttrTypeReplacer::CyclicAttrTypeReplacer() + : cache([&](void *attr) { return breakCycleImpl(attr); }) {} + +void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn fn) { + attrCycleBreakerFns.emplace_back(std::move(fn)); +} + +void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn fn) { + typeCycleBreakerFns.emplace_back(std::move(fn)); +} + +template +T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) { + void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue(); + CyclicReplacerCache::CacheEntry cacheEntry = + cache.lookupOrInit(opaqueTaggedElement); + if (auto resultOpt = cacheEntry.get()) + return T::getFromOpaquePointer(*resultOpt); + + T result = replaceBase(element); + + cacheEntry.resolve(result.getAsOpaquePointer()); + return result; +} + +Attribute CyclicAttrTypeReplacer::replace(Attribute attr) { + return cachedReplaceImpl(attr); +} + +Type CyclicAttrTypeReplacer::replace(Type type) { + return cachedReplaceImpl(type); +} + +std::optional +CyclicAttrTypeReplacer::breakCycleImpl(void *element) { + AttrOrType attrType = AttrOrType::getFromOpaqueValue(element); + if (auto attr = dyn_cast(attrType)) { + for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) { + if (std::optional newRes = cyclicReplaceFn(attr)) { + return newRes->getAsOpaquePointer(); + } + } + } else { + auto type = dyn_cast(attrType); + for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) { + if (std::optional newRes = cyclicReplaceFn(type)) { + return newRes->getAsOpaquePointer(); + } + } + } + return std::nullopt; } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp new file mode 100644 index 000000000000000..c7b42eb267c7ade --- /dev/null +++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp @@ -0,0 +1,231 @@ +//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacer +//===----------------------------------------------------------------------===// + +TEST(CyclicAttrTypeReplacerTest, testNoRecursion) { + MLIRContext ctx; + + CyclicAttrTypeReplacer replacer; + replacer.addReplacement([&](BoolAttr b) { + return StringAttr::get(&ctx, b.getValue() ? "true" : "false"); + }); + + EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)), + StringAttr::get(&ctx, "true")); + EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)), + StringAttr::get(&ctx, "false")); + EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), + mlir::UnitAttr::get(&ctx)); +} + +TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) { + MLIRContext ctx; + Builder b(&ctx); + + CyclicAttrTypeReplacer replacer; + // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ... + replacer.addReplacement([&](IntegerAttr attr) { + return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3)); + }); + // The first repeat of any integer attr is pruned into a unit attr. + replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); }); + + // No recursion case. + EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), + mlir::UnitAttr::get(&ctx)); + // Starting at 0. + EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx)); + // Starting at 2. + EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx)); +} + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacerTest: ChainRecursion +//===----------------------------------------------------------------------===// + +class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test { +public: + CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) { + // IntegerType + // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>. + // This will create a chain of infinite length without recursion pruning. + replacer.addReplacement([&](mlir::IntegerType intType) { + ++invokeCount; + return b.getFunctionType( + {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)}); + }); + } + + void setBaseCase(std::optional pruneAt) { + replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { + return (!pruneAt || intType.getWidth() == *pruneAt) + ? std::make_optional(b.getIndexType()) + : std::nullopt; + }); + } + + Type getFunctionTypeChain(unsigned N) { + Type type = b.getIndexType(); + for (unsigned i = 0; i < N; i++) + type = b.getFunctionType({}, type); + return type; + }; + + MLIRContext ctx; + Builder b; + CyclicAttrTypeReplacer replacer; + int invokeCount = 0; +}; + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) { + setBaseCase(std::nullopt); + + // No recursion case. + EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); + EXPECT_EQ(invokeCount, 0); + + // Starting at 0. Cycle length is 3. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); + + // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(5)); + EXPECT_EQ(invokeCount, 2); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) { + setBaseCase(std::nullopt); + + // Starting at 1. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) { + setBaseCase(0); + + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) { + setBaseCase(0); + + // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(5)); + EXPECT_EQ(invokeCount, 5); +} + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacerTest: BranchingRecusion +//===----------------------------------------------------------------------===// + +class CyclicAttrTypeReplacerBranchingRecusionPruningTest + : public ::testing::Test { +public: + CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) { + // IntegerType + // ==> FunctionType< + // IntegerType< width = (N+1) % 3> => + // IntegerType< width = (N+1) % 3>>. + // This will create a binary tree of infinite depth without pruning. + replacer.addReplacement([&](mlir::IntegerType intType) { + ++invokeCount; + Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3); + return b.getFunctionType({child}, {child}); + }); + } + + void setBaseCase(std::optional pruneAt) { + replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { + return (!pruneAt || intType.getWidth() == *pruneAt) + ? std::make_optional(b.getIndexType()) + : std::nullopt; + }); + } + + Type getFunctionTypeTree(unsigned N) { + Type type = b.getIndexType(); + for (unsigned i = 0; i < N; i++) + type = b.getFunctionType(type, type); + return type; + }; + + MLIRContext ctx; + Builder b; + CyclicAttrTypeReplacer replacer; + int invokeCount = 0; +}; + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) { + setBaseCase(std::nullopt); + + // No recursion case. + EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); + EXPECT_EQ(invokeCount, 0); + + // Starting at 0. Cycle length is 3. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeTree(3)); + // Since both branches are identical, this should incur linear invocations + // of the replacement function instead of exponential. + EXPECT_EQ(invokeCount, 3); + + // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(5)); + EXPECT_EQ(invokeCount, 2); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) { + setBaseCase(std::nullopt); + + // Starting at 1. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) { + setBaseCase(0); + + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeTree(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) { + setBaseCase(0); + + // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(5)); + EXPECT_EQ(invokeCount, 5); +} diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 71f8f449756ec01..05cb36e19031635 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_unittest(MLIRIRTests AffineExprTest.cpp AffineMapTest.cpp AttributeTest.cpp + AttrTypeReplacerTest.cpp DialectTest.cpp InterfaceTest.cpp IRMapping.cpp