From bcc3edaa9b80c1ee3e87ad15894382a065bef9ad Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Tue, 9 Jul 2024 10:26:05 -0700 Subject: [PATCH 1/5] add replacer cache and gtests --- .../mlir/Support/CyclicReplacerCache.h | 277 ++++++++++ mlir/unittests/Support/CMakeLists.txt | 1 + .../Support/CyclicReplacerCacheTest.cpp | 472 ++++++++++++++++++ 3 files changed, 750 insertions(+) create mode 100644 mlir/include/mlir/Support/CyclicReplacerCache.h create mode 100644 mlir/unittests/Support/CyclicReplacerCacheTest.cpp diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h new file mode 100644 index 000000000000000..9a703676fff11e4 --- /dev/null +++ b/mlir/include/mlir/Support/CyclicReplacerCache.h @@ -0,0 +1,277 @@ +//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains helper classes for caching replacer-like functions that +// map values between two domains. They are able to handle replacer logic that +// contains self-recursion. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_CACHINGREPLACER_H +#define MLIR_SUPPORT_CACHINGREPLACER_H + +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include + +namespace mlir { + +//===----------------------------------------------------------------------===// +// CyclicReplacerCache +//===----------------------------------------------------------------------===// + +/// A cache for replacer-like functions that map values between two domains. The +/// difference compared to just using a map to cache in-out pairs is that this +/// class is able to handle replacer logic that is self-recursive (and thus may +/// cause infinite recursion in the naive case). +/// +/// This class provides a hook for the user to perform cycle pruning when a +/// cycle is identified, and is able to perform context-sensitive caching so +/// that the replacement result for an input that is part of a pruned cycle can +/// be distinct from the replacement result for the same input when it is not +/// part of a cycle. +/// +/// In addition, this class allows deferring cycle pruning until specific inputs +/// are repeated. This is useful for cases where not all elements in a cycle can +/// perform pruning. The user still must guarantee that at least one element in +/// any given cycle can perform pruning. Even if not, an assertion will +/// eventually be tripped instead of infinite recursion (the run-time is +/// linearly bounded by the maximum cycle length of its input). +template +class CyclicReplacerCache { +public: + /// User-provided replacement function & cycle-breaking functions. + /// The cycle-breaking function must not make any more recursive invocations + /// to this cached replacer. + using CycleBreakerFn = std::function(const InT &)>; + + CyclicReplacerCache() = delete; + CyclicReplacerCache(CycleBreakerFn cycleBreaker) + : cycleBreaker(std::move(cycleBreaker)) {} + + /// A possibly unresolved cache entry. + /// If unresolved, the entry must be resolved before it goes out of scope. + struct CacheEntry { + public: + ~CacheEntry() { assert(result && "unresovled cache entry"); } + + /// Check whether this node was repeated during recursive replacements. + /// This only makes sense to be called after all recursive replacements are + /// completed and the current element has resurfaced to the top of the + /// replacement stack. + bool wasRepeated() const { + // If the top frame includes itself as a dependency, then it must have + // been repeated. + ReplacementFrame &currFrame = cache.replacementStack.back(); + size_t currFrameIndex = cache.replacementStack.size() - 1; + return currFrame.dependentFrames.count(currFrameIndex); + } + + /// Resolve an unresolved cache entry by providing the result to be stored + /// in the cache. + void resolve(OutT result) { + assert(!this->result && "cache entry already resolved"); + this->result = result; + cache.finalizeReplacement(element, result); + } + + /// Get the resolved result if one exists. + std::optional get() { return result; } + + private: + friend class CyclicReplacerCache; + CacheEntry() = delete; + CacheEntry(CyclicReplacerCache &cache, InT element, + std::optional result = std::nullopt) + : cache(cache), element(element), result(result) {} + + CyclicReplacerCache &cache; + InT element; + std::optional result; + }; + + /// Lookup the cache for a pre-calculated replacement for `element`. + /// If one exists, a resolved CacheEntry will be returned. Otherwise, an + /// unresolved CacheEntry will be returned, and the caller must resolve it + /// with the calculated replacement so it can be registered in the cache for + /// future use. + /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved + /// CacheEntries that are returned must be resolved in reverse order of + /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and + /// the first retrieved CacheEntry must be resolved last. This should be + /// natural when used as a stack / inside recursion. + CacheEntry lookupOrInit(const InT &element); + +private: + /// Register the replacement in the cache and update the replacementStack. + void finalizeReplacement(const InT &element, const OutT &result); + + CycleBreakerFn cycleBreaker; + DenseMap standaloneCache; + + struct DependentReplacement { + OutT replacement; + /// The highest replacement frame index that this cache entry is dependent + /// on. + size_t highestDependentFrame; + }; + DenseMap dependentCache; + + struct ReplacementFrame { + /// The set of elements that is only legal while under this current frame. + /// They need to be removed from the cache when this frame is popped off the + /// replacement stack. + DenseSet dependingReplacements; + /// The set of frame indices that this current frame's replacement is + /// dependent on, ordered from highest to lowest. + std::set> dependentFrames; + }; + /// Every element currently in the progress of being replaced pushes a frame + /// onto this stack. + SmallVector replacementStack; + /// Maps from each input element to its indices on the replacement stack. + DenseMap> cyclicElementFrame; + /// If set to true, we are currently asking an element to break a cycle. No + /// more recursive invocations is allowed while this is true (the replacement + /// stack can no longer grow). + bool resolvingCycle = false; +}; + +template +typename CyclicReplacerCache::CacheEntry +CyclicReplacerCache::lookupOrInit(const InT &element) { + assert(!resolvingCycle && + "illegal recursive invocation while breaking cycle"); + + if (auto it = standaloneCache.find(element); it != standaloneCache.end()) + return CacheEntry(*this, element, it->second); + + if (auto it = dependentCache.find(element); it != dependentCache.end()) { + // pdate the current top frame (the element that invoked this current + // replacement) to include any dependencies the cache entry had. + ReplacementFrame &currFrame = replacementStack.back(); + currFrame.dependentFrames.insert(it->second.highestDependentFrame); + return CacheEntry(*this, element, it->second.replacement); + } + + auto [it, inserted] = cyclicElementFrame.try_emplace(element); + if (!inserted) { + // This is a repeat of a known element. Try to break cycle here. + resolvingCycle = true; + std::optional result = cycleBreaker(element); + resolvingCycle = false; + if (result) { + // Cycle was broken. + size_t dependentFrame = it->second.back(); + dependentCache[element] = {*result, dependentFrame}; + ReplacementFrame &currFrame = replacementStack.back(); + // If this is a repeat, there is no replacement frame to pop. Mark the top + // frame as being dependent on this element. + currFrame.dependentFrames.insert(dependentFrame); + + return CacheEntry(*this, element, *result); + } + + // Cycle could not be broken. + // A legal setup must ensure at least one element of each cycle can break + // cycles. Under this setup, each element can be seen at most twice before + // the cycle is broken. If we see an element more than twice, we know this + // is an illegal setup. + assert(it->second.size() <= 2 && "illegal 3rd repeat of input"); + } + + // Otherwise, either this is the first time we see this element, or this + // element could not break this cycle. + it->second.push_back(replacementStack.size()); + replacementStack.emplace_back(); + + return CacheEntry(*this, element); +} + +template +void CyclicReplacerCache::finalizeReplacement(const InT &element, + const OutT &result) { + ReplacementFrame &currFrame = replacementStack.back(); + // With the conclusion of this replacement frame, the current element is no + // longer a dependent element. + currFrame.dependentFrames.erase(replacementStack.size() - 1); + + auto prevLayerIter = ++replacementStack.rbegin(); + if (prevLayerIter == replacementStack.rend()) { + // If this is the last frame, there should be zero dependents. + assert(currFrame.dependentFrames.empty() && + "internal error: top-level dependent replacement"); + // Cache standalone result. + standaloneCache[element] = result; + } else if (currFrame.dependentFrames.empty()) { + // Cache standalone result. + standaloneCache[element] = result; + } else { + // Cache dependent result. + size_t highestDependentFrame = *currFrame.dependentFrames.begin(); + dependentCache[element] = {result, highestDependentFrame}; + + // Otherwise, the previous frame inherits the same dependent frames. + prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(), + currFrame.dependentFrames.end()); + + // Mark this current replacement as a depending replacement on the closest + // dependent frame. + replacementStack[highestDependentFrame].dependingReplacements.insert( + element); + } + + // All depending replacements in the cache must be purged. + for (InT key : currFrame.dependingReplacements) + dependentCache.erase(key); + + replacementStack.pop_back(); + auto it = cyclicElementFrame.find(element); + it->second.pop_back(); + if (it->second.empty()) + cyclicElementFrame.erase(it); +} + +//===----------------------------------------------------------------------===// +// CachedCyclicReplacer +//===----------------------------------------------------------------------===// + +/// A helper class for cases where the input/output types of the replacer +/// function is identical to the types stored in the cache. This class wraps +/// the user-provided replacer function, and can be used in place of the user +/// function. +template +class CachedCyclicReplacer { +public: + using ReplacerFn = std::function; + using CycleBreakerFn = + typename CyclicReplacerCache::CycleBreakerFn; + + CachedCyclicReplacer() = delete; + CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker) + : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {} + + OutT operator()(const InT &element) { + auto cacheEntry = cache.lookupOrInit(element); + if (std::optional result = cacheEntry.get()) + return *result; + + OutT result = replacer(element); + cacheEntry.resolve(result); + return result; + } + +private: + ReplacerFn replacer; + CyclicReplacerCache cache; +}; + +} // namespace mlir + +#endif // MLIR_SUPPORT_CACHINGREPLACER_H diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt index 1dbf072bcbbfd10..ec79a1c64090924 100644 --- a/mlir/unittests/Support/CMakeLists.txt +++ b/mlir/unittests/Support/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_unittest(MLIRSupportTests + CyclicReplacerCacheTest.cpp IndentedOstreamTest.cpp StorageUniquerTest.cpp ) diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp new file mode 100644 index 000000000000000..23748e29765cbd8 --- /dev/null +++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp @@ -0,0 +1,472 @@ +//===- CyclicReplacerCacheTest.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/CyclicReplacerCache.h" +#include "llvm/ADT/SetVector.h" +#include "gmock/gmock.h" +#include +#include + +using namespace mlir; + +TEST(CachedCyclicReplacerTest, testNoRecursion) { + CachedCyclicReplacer replacer( + /*replacer=*/[](int n) { return static_cast(n); }, + /*cycleBreaker=*/[](int n) { return std::nullopt; }); + + EXPECT_EQ(replacer(3), true); + EXPECT_EQ(replacer(0), false); +} + +TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) { + // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ... + CachedCyclicReplacer replacer( + /*replacer=*/[&](int n) { return replacer((n + 1) % 3); }, + /*cycleBreaker=*/[&](int n) { return -1; }); + + // Starting at 0. + EXPECT_EQ(replacer(0), -1); + // Starting at 2. + EXPECT_EQ(replacer(2), -1); +} + +//===----------------------------------------------------------------------===// +// CachedCyclicReplacer: ChainRecursion +//===----------------------------------------------------------------------===// + +/// This set of tests uses a replacer function that maps ints into vectors of +/// ints. +/// +/// The replacement result for input `n` is the replacement result of `(n+1)%3` +/// appended with an element `42`. Theoretically, this will produce an +/// infinitely long vector. The cycle-breaker function prunes this infinite +/// recursion in the replacer logic by returning an empty vector upon the first +/// re-occurrence of an input value. +class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test { +public: + // N ==> (N+1) % 3 + // This will create a chain of infinite length without recursion pruning. + CachedCyclicReplacerChainRecursionPruningTest() + : replacer( + [&](int n) { + ++invokeCount; + std::vector result = replacer((n + 1) % 3); + result.push_back(42); + return result; + }, + [&](int n) -> std::optional> { + return baseCase.value_or(n) == n + ? std::make_optional(std::vector{}) + : std::nullopt; + }) {} + + std::vector getChain(unsigned N) { return std::vector(N, 42); }; + + CachedCyclicReplacer> replacer; + int invokeCount = 0; + std::optional baseCase = std::nullopt; +}; + +TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) { + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer(0), getChain(3)); + EXPECT_EQ(invokeCount, 3); + + // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. + invokeCount = 0; + EXPECT_EQ(replacer(1), getChain(5)); + EXPECT_EQ(invokeCount, 2); + + // Starting at 2. Cycle length is 4. Entire result is cached. + invokeCount = 0; + EXPECT_EQ(replacer(2), getChain(4)); + EXPECT_EQ(invokeCount, 0); +} + +TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) { + // Starting at 1. Cycle length is 3. + EXPECT_EQ(replacer(1), getChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) { + baseCase = 0; + + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer(0), getChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) { + baseCase = 0; + + // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). + EXPECT_EQ(replacer(1), getChain(5)); + EXPECT_EQ(invokeCount, 5); + + // Starting at 0. Cycle length is 3. Entire result is cached. + invokeCount = 0; + EXPECT_EQ(replacer(0), getChain(3)); + EXPECT_EQ(invokeCount, 0); +} + +//===----------------------------------------------------------------------===// +// CachedCyclicReplacer: GraphReplacement +//===----------------------------------------------------------------------===// + +/// This set of tests uses a replacer function that maps from cyclic graphs to +/// trees, pruning out cycles in the process. +/// +/// It consists of two helper classes: +/// - Graph +/// - A directed graph where nodes are non-negative integers. +/// - PrunedGraph +/// - A Graph where edges that used to cause cycles are now represented with +/// an indirection (a recursionId). +class CachedCyclicReplacerGraphReplacement : public ::testing::Test { +public: + /// A directed graph where nodes are non-negative integers. + struct Graph { + using Node = int64_t; + + /// Use ordered containers for deterministic output. + /// Nodes without outgoing edges are considered nonexistent. + std::map> edges; + + void addEdge(Node src, Node sink) { edges[src].insert(sink); } + + bool isCyclic() const { + DenseSet visited; + for (Node root : llvm::make_first_range(edges)) { + if (visited.contains(root)) + continue; + + SetVector path; + SmallVector workstack; + workstack.push_back(root); + while (!workstack.empty()) { + Node curr = workstack.back(); + workstack.pop_back(); + + if (curr < 0) { + // A negative node signals the end of processing all of this node's + // children. Remove self from path. + assert(path.back() == -curr && "internal inconsistency"); + path.pop_back(); + continue; + } + + if (path.contains(curr)) + return true; + + visited.insert(curr); + auto edgesIter = edges.find(curr); + if (edgesIter == edges.end() || edgesIter->second.empty()) + continue; + + path.insert(curr); + // Push negative node to signify recursion return. + workstack.push_back(-curr); + workstack.insert(workstack.end(), edgesIter->second.begin(), + edgesIter->second.end()); + } + } + return false; + } + + /// Deterministic output for testing. + std::string serialize() const { + std::ostringstream oss; + for (const auto &[src, neighbors] : edges) { + oss << src << ":"; + for (Graph::Node neighbor : neighbors) + oss << " " << neighbor; + oss << "\n"; + } + return oss.str(); + } + }; + + /// A Graph where edges that used to cause cycles (back-edges) are now + /// represented with an indirection (a recursionId). + /// + /// In addition to each node being an integer, each node also tracks the + /// original integer id it had in the original graph. This way for every + /// back-edge, we can represent it as pointing to a new instance of the + /// original node. Then we mark the original node and the new instance with + /// a new unique recursionId to indicate that they're supposed to be the same + /// graph. + struct PrunedGraph { + using Node = Graph::Node; + struct NodeInfo { + Graph::Node originalId; + // A negative recursive index means not recursive. + int64_t recursionId; + }; + + /// Add a regular non-recursive-self node. + Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) { + Node id = nextConnectionId++; + info[id] = {originalId, recursionIndex}; + return id; + } + /// Add a recursive-self-node, i.e. a duplicate of the original node that is + /// meant to represent an indirection to it. + std::pair addRecursiveSelfNode(Graph::Node originalId) { + return {addNode(originalId, nextRecursionId), nextRecursionId++}; + } + void addEdge(Node src, Node sink) { connections.addEdge(src, sink); } + + /// Deterministic output for testing. + std::string serialize() const { + std::ostringstream oss; + oss << "nodes\n"; + for (const auto &[nodeId, nodeInfo] : info) { + oss << nodeId << ": n" << nodeInfo.originalId; + if (nodeInfo.recursionId >= 0) + oss << '<' << nodeInfo.recursionId << '>'; + oss << "\n"; + } + oss << "edges\n"; + oss << connections.serialize(); + return oss.str(); + } + + bool isCyclic() const { return connections.isCyclic(); } + + private: + Graph connections; + int64_t nextRecursionId = 0; + int64_t nextConnectionId = 0; + // Use ordered map for deterministic output. + std::map info; + }; + + PrunedGraph breakCycles(const Graph &input) { + assert(input.isCyclic() && "input graph is not cyclic"); + + PrunedGraph output; + + DenseMap recMap; + auto cycleBreaker = [&](Graph::Node inNode) -> std::optional { + auto [node, recId] = output.addRecursiveSelfNode(inNode); + recMap[inNode] = recId; + return node; + }; + + CyclicReplacerCache cache(cycleBreaker); + + std::function replaceNode = + [&](Graph::Node inNode) { + auto cacheEntry = cache.lookupOrInit(inNode); + if (std::optional result = cacheEntry.get()) + return *result; + + // Recursively replace its neighbors. + SmallVector neighbors; + if (auto it = input.edges.find(inNode); it != input.edges.end()) + neighbors = SmallVector( + llvm::map_range(it->second, replaceNode)); + + // Create a new node in the output graph. + int64_t recursionIndex = + cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1; + Graph::Node result = output.addNode(inNode, recursionIndex); + + for (Graph::Node neighbor : neighbors) + output.addEdge(result, neighbor); + + cacheEntry.resolve(result); + return result; + }; + + /// Translate starting from each node. + for (Graph::Node root : llvm::make_first_range(input.edges)) + replaceNode(root); + + return output; + } + + /// Helper for serialization tests that allow putting comments in the + /// serialized format. Every line that begins with a `;` is considered a + /// comment. The entire line, incl. the terminating `\n` is removed. + std::string trimComments(StringRef input) { + std::ostringstream oss; + bool isNewLine = false; + bool isComment = false; + for (char c : input) { + // Lines beginning with ';' are comments. + if (isNewLine && c == ';') + isComment = true; + + if (!isComment) + oss << c; + + if (c == '\n') { + isNewLine = true; + isComment = false; + } + } + return oss.str(); + } +}; + +TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) { + // 0 -> 1 -> 2 + // ^ | + // +---------+ + Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}}; + PrunedGraph output = breakCycles(input); + ASSERT_FALSE(output.isCyclic()) << output.serialize(); + EXPECT_EQ(output.serialize(), trimComments(R"(nodes +; root 0 +0: n0<0> +1: n2 +2: n1 +3: n0<0> +; root 1 +4: n2 +; root 2 +5: n1 +edges +1: 0 +2: 1 +3: 2 +4: 3 +5: 4 +)")); +} + +TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) { + // +----> 1 -----+ + // | v + // 0 <---------- 3 + // | ^ + // +----> 2 -----+ + // + // Two loops: + // 0 -> 1 -> 3 -> 0 + // 0 -> 2 -> 3 -> 0 + Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}}; + PrunedGraph output = breakCycles(input); + ASSERT_FALSE(output.isCyclic()) << output.serialize(); + EXPECT_EQ(output.serialize(), trimComments(R"(nodes +; root 0 +0: n0<0> +1: n3 +2: n1 +3: n2 +4: n0<0> +; root 1 +5: n3 +6: n1 +; root 2 +7: n2 +edges +1: 0 +2: 1 +3: 1 +4: 2 3 +5: 4 +6: 5 +7: 5 +)")); +} + +TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) { + // +----> 1 -----+ + // | ^ v + // 0 <----+----- 2 + // + // Two nested loops: + // 0 -> 1 -> 2 -> 0 + // 1 -> 2 -> 1 + Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}}; + PrunedGraph output = breakCycles(input); + ASSERT_FALSE(output.isCyclic()) << output.serialize(); + EXPECT_EQ(output.serialize(), trimComments(R"(nodes +; root 0 +0: n0<0> +1: n1<1> +2: n2 +3: n1<1> +4: n0<0> +; root 1 +5: n1<2> +6: n2 +7: n1<2> +; root 2 +8: n2 +edges +2: 0 1 +3: 2 +4: 3 +6: 4 5 +7: 6 +8: 4 7 +)")); +} + +TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) { + // +----> 1 -----+ + // | ^ v + // 0 <----+----- 3 + // | v ^ + // +----> 2 -----+ + // + // Two sets of nested loops: + // 0 -> 1 -> 3 -> 0 + // 1 -> 3 -> 1 + // 0 -> 2 -> 3 -> 0 + // 2 -> 3 -> 2 + Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}}; + PrunedGraph output = breakCycles(input); + ASSERT_FALSE(output.isCyclic()) << output.serialize(); + EXPECT_EQ(output.serialize(), trimComments(R"(nodes +; root 0 +0: n0<0> +1: n1<1> +2: n3<2> +3: n2 +4: n3<2> +5: n1<1> +6: n2<3> +7: n3 +8: n2<3> +9: n0<0> +; root 1 +10: n1<4> +11: n3<5> +12: n2 +13: n3<5> +14: n1<4> +; root 2 +15: n2<6> +16: n3 +17: n2<6> +; root 3 +18: n3 +edges +; root 0 +3: 2 +4: 0 1 3 +5: 4 +7: 0 5 6 +8: 7 +9: 5 8 +; root 1 +12: 11 +13: 9 10 12 +14: 13 +; root 2 +16: 9 14 15 +17: 16 +; root 3 +18: 9 14 17 +)")); +} From 8d1dd886c4a80507c8a97dda15e91acbfa7c3619 Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Tue, 9 Jul 2024 10:27:13 -0700 Subject: [PATCH 2/5] refactor attrtype replacers & add tests --- mlir/include/mlir/IR/AttrTypeSubElements.h | 138 ++++++++++-- mlir/lib/IR/AttrTypeSubElements.cpp | 146 ++++++++++--- mlir/unittests/IR/AttrTypeReplacerTest.cpp | 231 +++++++++++++++++++++ mlir/unittests/IR/CMakeLists.txt | 1 + 4 files changed, 467 insertions(+), 49 deletions(-) create mode 100644 mlir/unittests/IR/AttrTypeReplacerTest.cpp diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h index 3105040b876317c..234767deea00afd 100644 --- a/mlir/include/mlir/IR/AttrTypeSubElements.h +++ b/mlir/include/mlir/IR/AttrTypeSubElements.h @@ -16,6 +16,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" +#include "mlir/Support/CyclicReplacerCache.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include @@ -116,9 +117,21 @@ class AttrTypeWalker { /// AttrTypeReplacer //===----------------------------------------------------------------------===// -/// This class provides a utility for replacing attributes/types, and their sub -/// elements. Multiple replacement functions may be registered. -class AttrTypeReplacer { +namespace detail { + +/// This class provides a base utility for replacing attributes/types, and their +/// sub elements. Multiple replacement functions may be registered. +/// +/// This base utility is uncached. Users can choose between two cached versions +/// of this replacer: +/// * For non-cyclic replacer logic, use `AttrTypeReplacer`. +/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`. +/// +/// Concrete implementations implement the following `replace` entry functions: +/// * Attribute replace(Attribute attr); +/// * Type replace(Type type); +template +class AttrTypeReplacerBase { public: //===--------------------------------------------------------------------===// // Application @@ -139,12 +152,6 @@ class AttrTypeReplacer { bool replaceLocs = false, bool replaceTypes = false); - /// Replace the given attribute/type, and recursively replace any sub - /// elements. Returns either the new attribute/type, or nullptr in the case of - /// failure. - Attribute replace(Attribute attr); - Type replace(Type type); - //===--------------------------------------------------------------------===// // Registration //===--------------------------------------------------------------------===// @@ -206,21 +213,114 @@ class AttrTypeReplacer { }); } -private: - /// Internal implementation of the `replace` methods above. - template - T replaceImpl(T element, ReplaceFns &replaceFns); - - /// Replace the sub elements of the given interface. - template - T replaceSubElements(T interface); +protected: + /// Invokes the registered replacement functions from most recently registered + /// to least recently registered until a successful replacement is returned. + /// Unless skipping is requested, invokes `replace` on sub-elements of the + /// current attr/type. + Attribute replaceBase(Attribute attr); + Type replaceBase(Type type); +private: /// The set of replacement functions that map sub elements. std::vector> attrReplacementFns; std::vector> typeReplacementFns; +}; + +} // namespace detail + +/// This is an attribute/type replacer that is naively cached. It is best used +/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any +/// re-occurrence of an in-progress element will be skipped. +class AttrTypeReplacer : public detail::AttrTypeReplacerBase { +public: + Attribute replace(Attribute attr); + Type replace(Type type); + +private: + /// Shared concrete implementation of the public `replace` functions. Invokes + /// `replaceBase` with caching. + template + T cachedReplaceImpl(T element); + + // Stores the opaque pointer of an attribute or type. + DenseMap cache; +}; + +/// This is an attribute/type replacer that supports custom handling of cycles +/// in the replacer logic. In addition to registering replacer functions, it +/// allows registering cycle-breaking functions in the same style. +class CyclicAttrTypeReplacer + : public detail::AttrTypeReplacerBase { +public: + CyclicAttrTypeReplacer(); - /// The set of cached mappings for attributes/types. - DenseMap attrTypeMap; + //===--------------------------------------------------------------------===// + // Application + //===--------------------------------------------------------------------===// + + Attribute replace(Attribute attr); + Type replace(Type type); + + //===--------------------------------------------------------------------===// + // Registration + //===--------------------------------------------------------------------===// + + /// A cycle-breaking function. This is invoked if the same element is asked to + /// be replaced again when the first instance of it is still being replaced. + /// This function must not perform any more recursive `replace` calls. + /// If it is able to break the cycle, it should return a replacement result. + /// Otherwise, it can return std::nullopt to defer cycle breaking to the next + /// repeated element. However, the user must guarantee that, in any possible + /// cycle, there always exists at least one element that can break the cycle. + template + using CycleBreakerFn = std::function(T)>; + + /// Register a cycle-breaking function. + /// When breaking cycles, the mostly recently added cycle-breaking functions + /// will be invoked first. + void addCycleBreaker(CycleBreakerFn fn); + void addCycleBreaker(CycleBreakerFn fn); + + /// Register a cycle-breaking function that doesn't match the default + /// signature. + template >::template arg_t<0>, + typename BaseT = std::conditional_t, + Attribute, Type>> + std::enable_if_t> addCycleBreaker(FnT &&callback) { + addCycleBreaker([callback = std::forward(callback)]( + BaseT base) -> std::optional { + if (auto derived = dyn_cast(base)) + return callback(derived); + return std::nullopt; + }); + } + +private: + /// Invokes the registered cycle-breaker functions from most recently + /// registered to least recently registered until a successful result is + /// returned. + std::optional breakCycleImpl(void *element); + + /// Shared concrete implementation of the public `replace` functions. + template + T cachedReplaceImpl(T element); + + /// The set of registered cycle-breaker functions. + std::vector> attrCycleBreakerFns; + std::vector> typeCycleBreakerFns; + + /// A cache of previously-replaced attr/types. + /// The key of the cache is the opaque value of an AttrOrType. Using + /// AttrOrType allows distinguishing between the two types when invoking + /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue + /// of directly using `AttrOrType` to instantiate the cache. + /// The value of the cache is just the opaque value of the attr/type itself + /// (not the PointerUnion). + using AttrOrType = PointerUnion; + CyclicReplacerCache cache; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp index 79b04966be6eb19..783236ed3a9df67 100644 --- a/mlir/lib/IR/AttrTypeSubElements.cpp +++ b/mlir/lib/IR/AttrTypeSubElements.cpp @@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { } //===----------------------------------------------------------------------===// -/// AttrTypeReplacer +/// AttrTypeReplacerBase //===----------------------------------------------------------------------===// -void AttrTypeReplacer::addReplacement(ReplaceFn fn) { +template +void detail::AttrTypeReplacerBase::addReplacement( + ReplaceFn fn) { attrReplacementFns.emplace_back(std::move(fn)); } -void AttrTypeReplacer::addReplacement(ReplaceFn fn) { + +template +void detail::AttrTypeReplacerBase::addReplacement( + ReplaceFn fn) { typeReplacementFns.push_back(std::move(fn)); } -void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, - bool replaceLocs, bool replaceTypes) { +template +void detail::AttrTypeReplacerBase::replaceElementsIn( + Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { // Functor that replaces the given element if the new value is different, // otherwise returns nullptr. auto replaceIfDifferent = [&](auto element) { - auto replacement = replace(element); + auto replacement = static_cast(this)->replace(element); return (replacement && replacement != element) ? replacement : nullptr; }; @@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, } } -void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op, - bool replaceAttrs, - bool replaceLocs, - bool replaceTypes) { +template +void detail::AttrTypeReplacerBase::recursivelyReplaceElementsIn( + Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { op->walk([&](Operation *nestedOp) { replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes); }); } -template -static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, +template +static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl &newElements, FailureOr &changed) { // Bail early if we failed at any point. @@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, } } -template -T AttrTypeReplacer::replaceSubElements(T interface) { +template +static T replaceSubElements(T interface, Replacer &replacer) { // Walk the current sub-elements, replacing them as necessary. SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; interface.walkImmediateSubElements( [&](Attribute element) { - updateSubElementImpl(element, *this, newAttrs, changed); + updateSubElementImpl(element, replacer, newAttrs, changed); }, [&](Type element) { - updateSubElementImpl(element, *this, newTypes, changed); + updateSubElementImpl(element, replacer, newTypes, changed); }); if (failed(changed)) return nullptr; @@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) { } /// Shared implementation of replacing a given attribute or type element. -template -T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { - const void *opaqueElement = element.getAsOpaquePointer(); - auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement); - if (!inserted) - return T::getFromOpaquePointer(it->second); - +template +static T replaceElementImpl(T element, ReplaceFns &replaceFns, + Replacer &replacer) { T result = element; WalkResult walkResult = WalkResult::advance(); for (auto &replaceFn : llvm::reverse(replaceFns)) { @@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { // If an error occurred, return nullptr to indicate failure. if (walkResult.wasInterrupted() || !result) { - attrTypeMap[opaqueElement] = nullptr; return nullptr; } // Handle replacing sub-elements if this element is also a container. if (!walkResult.wasSkipped()) { // Replace the sub elements of this element, bailing if we fail. - if (!(result = replaceSubElements(result))) { - attrTypeMap[opaqueElement] = nullptr; + if (!(result = replaceSubElements(result, replacer))) { return nullptr; } } - attrTypeMap[opaqueElement] = result.getAsOpaquePointer(); + return result; +} + +template +Attribute detail::AttrTypeReplacerBase::replaceBase(Attribute attr) { + return replaceElementImpl(attr, attrReplacementFns, + *static_cast(this)); +} + +template +Type detail::AttrTypeReplacerBase::replaceBase(Type type) { + return replaceElementImpl(type, typeReplacementFns, + *static_cast(this)); +} + +//===----------------------------------------------------------------------===// +/// AttrTypeReplacer +//===----------------------------------------------------------------------===// + +template class detail::AttrTypeReplacerBase; + +template +T AttrTypeReplacer::cachedReplaceImpl(T element) { + const void *opaqueElement = element.getAsOpaquePointer(); + auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement); + if (!inserted) + return T::getFromOpaquePointer(it->second); + + T result = replaceBase(element); + + cache[opaqueElement] = result.getAsOpaquePointer(); return result; } Attribute AttrTypeReplacer::replace(Attribute attr) { - return replaceImpl(attr, attrReplacementFns); + return cachedReplaceImpl(attr); } -Type AttrTypeReplacer::replace(Type type) { - return replaceImpl(type, typeReplacementFns); +Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); } + +//===----------------------------------------------------------------------===// +/// CyclicAttrTypeReplacer +//===----------------------------------------------------------------------===// + +template class detail::AttrTypeReplacerBase; + +CyclicAttrTypeReplacer::CyclicAttrTypeReplacer() + : cache([&](void *attr) { return breakCycleImpl(attr); }) {} + +void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn fn) { + attrCycleBreakerFns.emplace_back(std::move(fn)); +} + +void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn fn) { + typeCycleBreakerFns.emplace_back(std::move(fn)); +} + +template +T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) { + void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue(); + CyclicReplacerCache::CacheEntry cacheEntry = + cache.lookupOrInit(opaqueTaggedElement); + if (auto resultOpt = cacheEntry.get()) + return T::getFromOpaquePointer(*resultOpt); + + T result = replaceBase(element); + + cacheEntry.resolve(result.getAsOpaquePointer()); + return result; +} + +Attribute CyclicAttrTypeReplacer::replace(Attribute attr) { + return cachedReplaceImpl(attr); +} + +Type CyclicAttrTypeReplacer::replace(Type type) { + return cachedReplaceImpl(type); +} + +std::optional +CyclicAttrTypeReplacer::breakCycleImpl(void *element) { + AttrOrType attrType = AttrOrType::getFromOpaqueValue(element); + if (auto attr = dyn_cast(attrType)) { + for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) { + if (std::optional newRes = cyclicReplaceFn(attr)) { + return newRes->getAsOpaquePointer(); + } + } + } else { + auto type = dyn_cast(attrType); + for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) { + if (std::optional newRes = cyclicReplaceFn(type)) { + return newRes->getAsOpaquePointer(); + } + } + } + return std::nullopt; } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp new file mode 100644 index 000000000000000..c7b42eb267c7ade --- /dev/null +++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp @@ -0,0 +1,231 @@ +//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacer +//===----------------------------------------------------------------------===// + +TEST(CyclicAttrTypeReplacerTest, testNoRecursion) { + MLIRContext ctx; + + CyclicAttrTypeReplacer replacer; + replacer.addReplacement([&](BoolAttr b) { + return StringAttr::get(&ctx, b.getValue() ? "true" : "false"); + }); + + EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)), + StringAttr::get(&ctx, "true")); + EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)), + StringAttr::get(&ctx, "false")); + EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), + mlir::UnitAttr::get(&ctx)); +} + +TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) { + MLIRContext ctx; + Builder b(&ctx); + + CyclicAttrTypeReplacer replacer; + // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ... + replacer.addReplacement([&](IntegerAttr attr) { + return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3)); + }); + // The first repeat of any integer attr is pruned into a unit attr. + replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); }); + + // No recursion case. + EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), + mlir::UnitAttr::get(&ctx)); + // Starting at 0. + EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx)); + // Starting at 2. + EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx)); +} + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacerTest: ChainRecursion +//===----------------------------------------------------------------------===// + +class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test { +public: + CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) { + // IntegerType + // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>. + // This will create a chain of infinite length without recursion pruning. + replacer.addReplacement([&](mlir::IntegerType intType) { + ++invokeCount; + return b.getFunctionType( + {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)}); + }); + } + + void setBaseCase(std::optional pruneAt) { + replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { + return (!pruneAt || intType.getWidth() == *pruneAt) + ? std::make_optional(b.getIndexType()) + : std::nullopt; + }); + } + + Type getFunctionTypeChain(unsigned N) { + Type type = b.getIndexType(); + for (unsigned i = 0; i < N; i++) + type = b.getFunctionType({}, type); + return type; + }; + + MLIRContext ctx; + Builder b; + CyclicAttrTypeReplacer replacer; + int invokeCount = 0; +}; + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) { + setBaseCase(std::nullopt); + + // No recursion case. + EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); + EXPECT_EQ(invokeCount, 0); + + // Starting at 0. Cycle length is 3. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); + + // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(5)); + EXPECT_EQ(invokeCount, 2); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) { + setBaseCase(std::nullopt); + + // Starting at 1. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) { + setBaseCase(0); + + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeChain(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) { + setBaseCase(0); + + // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeChain(5)); + EXPECT_EQ(invokeCount, 5); +} + +//===----------------------------------------------------------------------===// +// CyclicAttrTypeReplacerTest: BranchingRecusion +//===----------------------------------------------------------------------===// + +class CyclicAttrTypeReplacerBranchingRecusionPruningTest + : public ::testing::Test { +public: + CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) { + // IntegerType + // ==> FunctionType< + // IntegerType< width = (N+1) % 3> => + // IntegerType< width = (N+1) % 3>>. + // This will create a binary tree of infinite depth without pruning. + replacer.addReplacement([&](mlir::IntegerType intType) { + ++invokeCount; + Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3); + return b.getFunctionType({child}, {child}); + }); + } + + void setBaseCase(std::optional pruneAt) { + replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { + return (!pruneAt || intType.getWidth() == *pruneAt) + ? std::make_optional(b.getIndexType()) + : std::nullopt; + }); + } + + Type getFunctionTypeTree(unsigned N) { + Type type = b.getIndexType(); + for (unsigned i = 0; i < N; i++) + type = b.getFunctionType(type, type); + return type; + }; + + MLIRContext ctx; + Builder b; + CyclicAttrTypeReplacer replacer; + int invokeCount = 0; +}; + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) { + setBaseCase(std::nullopt); + + // No recursion case. + EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); + EXPECT_EQ(invokeCount, 0); + + // Starting at 0. Cycle length is 3. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeTree(3)); + // Since both branches are identical, this should incur linear invocations + // of the replacement function instead of exponential. + EXPECT_EQ(invokeCount, 3); + + // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. + invokeCount = 0; + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(5)); + EXPECT_EQ(invokeCount, 2); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) { + setBaseCase(std::nullopt); + + // Starting at 1. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) { + setBaseCase(0); + + // Starting at 0. Cycle length is 3. + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), + getFunctionTypeTree(3)); + EXPECT_EQ(invokeCount, 3); +} + +TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) { + setBaseCase(0); + + // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). + EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), + getFunctionTypeTree(5)); + EXPECT_EQ(invokeCount, 5); +} diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 71f8f449756ec01..05cb36e19031635 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_unittest(MLIRIRTests AffineExprTest.cpp AffineMapTest.cpp AttributeTest.cpp + AttrTypeReplacerTest.cpp DialectTest.cpp InterfaceTest.cpp IRMapping.cpp From 5ccbf4bc13ac48240d9fcb14411617809e527304 Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Tue, 9 Jul 2024 13:58:34 -0700 Subject: [PATCH 3/5] typo & comments --- mlir/unittests/Support/CyclicReplacerCacheTest.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp index 23748e29765cbd8..a4a92dbe147d446 100644 --- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp +++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp @@ -195,17 +195,19 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test { /// A Graph where edges that used to cause cycles (back-edges) are now /// represented with an indirection (a recursionId). /// - /// In addition to each node being an integer, each node also tracks the - /// original integer id it had in the original graph. This way for every + /// In addition to each node having an integer ID, each node also tracks the + /// original integer ID it had in the original graph. This way for every /// back-edge, we can represent it as pointing to a new instance of the /// original node. Then we mark the original node and the new instance with /// a new unique recursionId to indicate that they're supposed to be the same - /// graph. + /// node. struct PrunedGraph { using Node = Graph::Node; struct NodeInfo { Graph::Node originalId; - // A negative recursive index means not recursive. + /// A negative recursive index means not recursive. Otherwise nodes with + /// the same originalId & recursionId are the same node in the original + /// graph. int64_t recursionId; }; @@ -243,7 +245,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test { Graph connections; int64_t nextRecursionId = 0; int64_t nextConnectionId = 0; - // Use ordered map for deterministic output. + /// Use ordered map for deterministic output. std::map info; }; From a27223a7375afaaf5abd66c26d6b0f4fda372abf Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Thu, 11 Jul 2024 10:59:52 -0700 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Jeff Niu --- mlir/include/mlir/Support/CyclicReplacerCache.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h index 9a703676fff11e4..7ad8c717199ecfc 100644 --- a/mlir/include/mlir/Support/CyclicReplacerCache.h +++ b/mlir/include/mlir/Support/CyclicReplacerCache.h @@ -82,14 +82,14 @@ class CyclicReplacerCache { } /// Get the resolved result if one exists. - std::optional get() { return result; } + const std::optional &get() { return result; } private: friend class CyclicReplacerCache; CacheEntry() = delete; CacheEntry(CyclicReplacerCache &cache, InT element, std::optional result = std::nullopt) - : cache(cache), element(element), result(result) {} + : cache(cache), element(std::move(element)), result(result) {} CyclicReplacerCache &cache; InT element; @@ -153,7 +153,7 @@ CyclicReplacerCache::lookupOrInit(const InT &element) { return CacheEntry(*this, element, it->second); if (auto it = dependentCache.find(element); it != dependentCache.end()) { - // pdate the current top frame (the element that invoked this current + // Update the current top frame (the element that invoked this current // replacement) to include any dependencies the cache entry had. ReplacementFrame &currFrame = replacementStack.back(); currFrame.dependentFrames.insert(it->second.highestDependentFrame); From 1abb63f81154de65757b25717a1e616b5b8971d1 Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Thu, 11 Jul 2024 11:18:02 -0700 Subject: [PATCH 5/5] double down on trivial type support --- .../mlir/Support/CyclicReplacerCache.h | 29 ++++++++++--------- .../Support/CyclicReplacerCacheTest.cpp | 4 +++ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h index 7ad8c717199ecfc..42428c1507ffb5d 100644 --- a/mlir/include/mlir/Support/CyclicReplacerCache.h +++ b/mlir/include/mlir/Support/CyclicReplacerCache.h @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_SUPPORT_CACHINGREPLACER_H -#define MLIR_SUPPORT_CACHINGREPLACER_H +#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H +#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H #include "mlir/IR/Visitors.h" #include "llvm/ADT/DenseSet.h" @@ -43,13 +43,16 @@ namespace mlir { /// any given cycle can perform pruning. Even if not, an assertion will /// eventually be tripped instead of infinite recursion (the run-time is /// linearly bounded by the maximum cycle length of its input). +/// +/// WARNING: This class works best with InT & OutT that are trivial scalar +/// types. The input/output elements will be frequently copied and hashed. template class CyclicReplacerCache { public: /// User-provided replacement function & cycle-breaking functions. /// The cycle-breaking function must not make any more recursive invocations /// to this cached replacer. - using CycleBreakerFn = std::function(const InT &)>; + using CycleBreakerFn = std::function(InT)>; CyclicReplacerCache() = delete; CyclicReplacerCache(CycleBreakerFn cycleBreaker) @@ -77,12 +80,12 @@ class CyclicReplacerCache { /// in the cache. void resolve(OutT result) { assert(!this->result && "cache entry already resolved"); - this->result = result; cache.finalizeReplacement(element, result); + this->result = std::move(result); } /// Get the resolved result if one exists. - const std::optional &get() { return result; } + const std::optional &get() const { return result; } private: friend class CyclicReplacerCache; @@ -106,11 +109,11 @@ class CyclicReplacerCache { /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and /// the first retrieved CacheEntry must be resolved last. This should be /// natural when used as a stack / inside recursion. - CacheEntry lookupOrInit(const InT &element); + CacheEntry lookupOrInit(InT element); private: /// Register the replacement in the cache and update the replacementStack. - void finalizeReplacement(const InT &element, const OutT &result); + void finalizeReplacement(InT element, OutT result); CycleBreakerFn cycleBreaker; DenseMap standaloneCache; @@ -145,7 +148,7 @@ class CyclicReplacerCache { template typename CyclicReplacerCache::CacheEntry -CyclicReplacerCache::lookupOrInit(const InT &element) { +CyclicReplacerCache::lookupOrInit(InT element) { assert(!resolvingCycle && "illegal recursive invocation while breaking cycle"); @@ -195,8 +198,8 @@ CyclicReplacerCache::lookupOrInit(const InT &element) { } template -void CyclicReplacerCache::finalizeReplacement(const InT &element, - const OutT &result) { +void CyclicReplacerCache::finalizeReplacement(InT element, + OutT result) { ReplacementFrame &currFrame = replacementStack.back(); // With the conclusion of this replacement frame, the current element is no // longer a dependent element. @@ -249,7 +252,7 @@ void CyclicReplacerCache::finalizeReplacement(const InT &element, template class CachedCyclicReplacer { public: - using ReplacerFn = std::function; + using ReplacerFn = std::function; using CycleBreakerFn = typename CyclicReplacerCache::CycleBreakerFn; @@ -257,7 +260,7 @@ class CachedCyclicReplacer { CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker) : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {} - OutT operator()(const InT &element) { + OutT operator()(InT element) { auto cacheEntry = cache.lookupOrInit(element); if (std::optional result = cacheEntry.get()) return *result; @@ -274,4 +277,4 @@ class CachedCyclicReplacer { } // namespace mlir -#endif // MLIR_SUPPORT_CACHINGREPLACER_H +#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp index a4a92dbe147d446..ca02a3d692b2a8d 100644 --- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp +++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp @@ -47,6 +47,7 @@ TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) { /// infinitely long vector. The cycle-breaker function prunes this infinite /// recursion in the replacer logic by returning an empty vector upon the first /// re-occurrence of an input value. +namespace { class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test { public: // N ==> (N+1) % 3 @@ -71,6 +72,7 @@ class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test { int invokeCount = 0; std::optional baseCase = std::nullopt; }; +} // namespace TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) { // Starting at 0. Cycle length is 3. @@ -128,6 +130,7 @@ TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) { /// - PrunedGraph /// - A Graph where edges that used to cause cycles are now represented with /// an indirection (a recursionId). +namespace { class CachedCyclicReplacerGraphReplacement : public ::testing::Test { public: /// A directed graph where nodes are non-negative integers. @@ -317,6 +320,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test { return oss.str(); } }; +} // namespace TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) { // 0 -> 1 -> 2