Skip to content

Commit

Permalink
[MLIR] Cyclic AttrType Replacer (llvm#98206)
Browse files Browse the repository at this point in the history
The current `AttrTypeReplacer` does not allow for custom handling of
replacer functions that may cause self-recursion. For example, the
replacement of one attr/type may depend on the replacement of another
attr/type (by calling into the replacer manually again), which in turn
may depend on the replacement of the original attr/type.

To enable this functionality, this PR broke out the original
AttrTypeReplacer into two parts:
- An uncached base version (`detail::AttrTypeReplacerBase`) that allows
registering replacer functions and has logic for invoking it on
attr/types & their sub-elements
- A cached version (`AttrTypeReplacer`) that provides the same caching
as the original one. This is still the one used everywhere and behavior
is unchanged.

On top of the uncached base version, a `CyclicAttrTypeReplacer` is
introduced that provides caching & cycle-handling for replacer logic
that is cyclic. Cycle-breaking & caching is provided by the
`CyclicReplacerCache` from
llvm#98202.

Both concrete implementations of the uncached base version use CRTP to
avoid dynamic dispatch. The base class merely provides replacer
registration & invocation, and is not meant to be used, or otherwise
extended elsewhere.
  • Loading branch information
zyx-billy authored and aaryanshukla committed Jul 14, 2024
1 parent 3bf250b commit 3f72be3
Show file tree
Hide file tree
Showing 4 changed files with 467 additions and 49 deletions.
138 changes: 119 additions & 19 deletions mlir/include/mlir/IR/AttrTypeSubElements.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/CyclicReplacerCache.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include <optional>
Expand Down Expand Up @@ -116,9 +117,21 @@ class AttrTypeWalker {
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//

/// This class provides a utility for replacing attributes/types, and their sub
/// elements. Multiple replacement functions may be registered.
class AttrTypeReplacer {
namespace detail {

/// This class provides a base utility for replacing attributes/types, and their
/// sub elements. Multiple replacement functions may be registered.
///
/// This base utility is uncached. Users can choose between two cached versions
/// of this replacer:
/// * For non-cyclic replacer logic, use `AttrTypeReplacer`.
/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
///
/// Concrete implementations implement the following `replace` entry functions:
/// * Attribute replace(Attribute attr);
/// * Type replace(Type type);
template <typename Concrete>
class AttrTypeReplacerBase {
public:
//===--------------------------------------------------------------------===//
// Application
Expand All @@ -139,12 +152,6 @@ class AttrTypeReplacer {
bool replaceLocs = false,
bool replaceTypes = false);

/// Replace the given attribute/type, and recursively replace any sub
/// elements. Returns either the new attribute/type, or nullptr in the case of
/// failure.
Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -206,21 +213,114 @@ class AttrTypeReplacer {
});
}

private:
/// Internal implementation of the `replace` methods above.
template <typename T, typename ReplaceFns>
T replaceImpl(T element, ReplaceFns &replaceFns);

/// Replace the sub elements of the given interface.
template <typename T>
T replaceSubElements(T interface);
protected:
/// Invokes the registered replacement functions from most recently registered
/// to least recently registered until a successful replacement is returned.
/// Unless skipping is requested, invokes `replace` on sub-elements of the
/// current attr/type.
Attribute replaceBase(Attribute attr);
Type replaceBase(Type type);

private:
/// The set of replacement functions that map sub elements.
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
std::vector<ReplaceFn<Type>> typeReplacementFns;
};

} // namespace detail

/// This is an attribute/type replacer that is naively cached. It is best used
/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
/// re-occurrence of an in-progress element will be skipped.
class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
public:
Attribute replace(Attribute attr);
Type replace(Type type);

private:
/// Shared concrete implementation of the public `replace` functions. Invokes
/// `replaceBase` with caching.
template <typename T>
T cachedReplaceImpl(T element);

// Stores the opaque pointer of an attribute or type.
DenseMap<const void *, const void *> cache;
};

/// This is an attribute/type replacer that supports custom handling of cycles
/// in the replacer logic. In addition to registering replacer functions, it
/// allows registering cycle-breaking functions in the same style.
class CyclicAttrTypeReplacer
: public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
public:
CyclicAttrTypeReplacer();

/// The set of cached mappings for attributes/types.
DenseMap<const void *, const void *> attrTypeMap;
//===--------------------------------------------------------------------===//
// Application
//===--------------------------------------------------------------------===//

Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//

/// A cycle-breaking function. This is invoked if the same element is asked to
/// be replaced again when the first instance of it is still being replaced.
/// This function must not perform any more recursive `replace` calls.
/// If it is able to break the cycle, it should return a replacement result.
/// Otherwise, it can return std::nullopt to defer cycle breaking to the next
/// repeated element. However, the user must guarantee that, in any possible
/// cycle, there always exists at least one element that can break the cycle.
template <typename T>
using CycleBreakerFn = std::function<std::optional<T>(T)>;

/// Register a cycle-breaking function.
/// When breaking cycles, the mostly recently added cycle-breaking functions
/// will be invoked first.
void addCycleBreaker(CycleBreakerFn<Attribute> fn);
void addCycleBreaker(CycleBreakerFn<Type> fn);

/// Register a cycle-breaking function that doesn't match the default
/// signature.
template <typename FnT,
typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<0>,
typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
Attribute, Type>>
std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
addCycleBreaker([callback = std::forward<FnT>(callback)](
BaseT base) -> std::optional<BaseT> {
if (auto derived = dyn_cast<T>(base))
return callback(derived);
return std::nullopt;
});
}

private:
/// Invokes the registered cycle-breaker functions from most recently
/// registered to least recently registered until a successful result is
/// returned.
std::optional<const void *> breakCycleImpl(void *element);

/// Shared concrete implementation of the public `replace` functions.
template <typename T>
T cachedReplaceImpl(T element);

/// The set of registered cycle-breaker functions.
std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;

/// A cache of previously-replaced attr/types.
/// The key of the cache is the opaque value of an AttrOrType. Using
/// AttrOrType allows distinguishing between the two types when invoking
/// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
/// of directly using `AttrOrType` to instantiate the cache.
/// The value of the cache is just the opaque value of the attr/type itself
/// (not the PointerUnion).
using AttrOrType = PointerUnion<Attribute, Type>;
CyclicReplacerCache<void *, const void *> cache;
};

//===----------------------------------------------------------------------===//
Expand Down
146 changes: 116 additions & 30 deletions mlir/lib/IR/AttrTypeSubElements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
}

//===----------------------------------------------------------------------===//
/// AttrTypeReplacer
/// AttrTypeReplacerBase
//===----------------------------------------------------------------------===//

void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
template <typename Concrete>
void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
ReplaceFn<Attribute> fn) {
attrReplacementFns.emplace_back(std::move(fn));
}
void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {

template <typename Concrete>
void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
ReplaceFn<Type> fn) {
typeReplacementFns.push_back(std::move(fn));
}

void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
bool replaceLocs, bool replaceTypes) {
template <typename Concrete>
void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
// Functor that replaces the given element if the new value is different,
// otherwise returns nullptr.
auto replaceIfDifferent = [&](auto element) {
auto replacement = replace(element);
auto replacement = static_cast<Concrete *>(this)->replace(element);
return (replacement && replacement != element) ? replacement : nullptr;
};

Expand Down Expand Up @@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
}
}

void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
bool replaceAttrs,
bool replaceLocs,
bool replaceTypes) {
template <typename Concrete>
void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
op->walk([&](Operation *nestedOp) {
replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
});
}

template <typename T>
static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
template <typename T, typename Replacer>
static void updateSubElementImpl(T element, Replacer &replacer,
SmallVectorImpl<T> &newElements,
FailureOr<bool> &changed) {
// Bail early if we failed at any point.
Expand All @@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
}
}

template <typename T>
T AttrTypeReplacer::replaceSubElements(T interface) {
template <typename T, typename Replacer>
static T replaceSubElements(T interface, Replacer &replacer) {
// Walk the current sub-elements, replacing them as necessary.
SmallVector<Attribute, 16> newAttrs;
SmallVector<Type, 16> newTypes;
FailureOr<bool> changed = false;
interface.walkImmediateSubElements(
[&](Attribute element) {
updateSubElementImpl(element, *this, newAttrs, changed);
updateSubElementImpl(element, replacer, newAttrs, changed);
},
[&](Type element) {
updateSubElementImpl(element, *this, newTypes, changed);
updateSubElementImpl(element, replacer, newTypes, changed);
});
if (failed(changed))
return nullptr;
Expand All @@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
}

/// Shared implementation of replacing a given attribute or type element.
template <typename T, typename ReplaceFns>
T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
const void *opaqueElement = element.getAsOpaquePointer();
auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
if (!inserted)
return T::getFromOpaquePointer(it->second);

template <typename T, typename ReplaceFns, typename Replacer>
static T replaceElementImpl(T element, ReplaceFns &replaceFns,
Replacer &replacer) {
T result = element;
WalkResult walkResult = WalkResult::advance();
for (auto &replaceFn : llvm::reverse(replaceFns)) {
Expand All @@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {

// If an error occurred, return nullptr to indicate failure.
if (walkResult.wasInterrupted() || !result) {
attrTypeMap[opaqueElement] = nullptr;
return nullptr;
}

// Handle replacing sub-elements if this element is also a container.
if (!walkResult.wasSkipped()) {
// Replace the sub elements of this element, bailing if we fail.
if (!(result = replaceSubElements(result))) {
attrTypeMap[opaqueElement] = nullptr;
if (!(result = replaceSubElements(result, replacer))) {
return nullptr;
}
}

attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
return result;
}

template <typename Concrete>
Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
return replaceElementImpl(attr, attrReplacementFns,
*static_cast<Concrete *>(this));
}

template <typename Concrete>
Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
return replaceElementImpl(type, typeReplacementFns,
*static_cast<Concrete *>(this));
}

//===----------------------------------------------------------------------===//
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//

template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;

template <typename T>
T AttrTypeReplacer::cachedReplaceImpl(T element) {
const void *opaqueElement = element.getAsOpaquePointer();
auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
if (!inserted)
return T::getFromOpaquePointer(it->second);

T result = replaceBase(element);

cache[opaqueElement] = result.getAsOpaquePointer();
return result;
}

Attribute AttrTypeReplacer::replace(Attribute attr) {
return replaceImpl(attr, attrReplacementFns);
return cachedReplaceImpl(attr);
}

Type AttrTypeReplacer::replace(Type type) {
return replaceImpl(type, typeReplacementFns);
Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }

//===----------------------------------------------------------------------===//
/// CyclicAttrTypeReplacer
//===----------------------------------------------------------------------===//

template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;

CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
: cache([&](void *attr) { return breakCycleImpl(attr); }) {}

void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
attrCycleBreakerFns.emplace_back(std::move(fn));
}

void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
typeCycleBreakerFns.emplace_back(std::move(fn));
}

template <typename T>
T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
cache.lookupOrInit(opaqueTaggedElement);
if (auto resultOpt = cacheEntry.get())
return T::getFromOpaquePointer(*resultOpt);

T result = replaceBase(element);

cacheEntry.resolve(result.getAsOpaquePointer());
return result;
}

Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
return cachedReplaceImpl(attr);
}

Type CyclicAttrTypeReplacer::replace(Type type) {
return cachedReplaceImpl(type);
}

std::optional<const void *>
CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
if (auto attr = dyn_cast<Attribute>(attrType)) {
for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
return newRes->getAsOpaquePointer();
}
}
} else {
auto type = dyn_cast<Type>(attrType);
for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
return newRes->getAsOpaquePointer();
}
}
}
return std::nullopt;
}

//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 3f72be3

Please sign in to comment.