Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Cyclic AttrType Replacer #98206

Merged
merged 7 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 119 additions & 19 deletions mlir/include/mlir/IR/AttrTypeSubElements.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/CyclicReplacerCache.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include <optional>
Expand Down Expand Up @@ -116,9 +117,21 @@ class AttrTypeWalker {
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//

/// This class provides a utility for replacing attributes/types, and their sub
/// elements. Multiple replacement functions may be registered.
class AttrTypeReplacer {
namespace detail {

/// This class provides a base utility for replacing attributes/types, and their
/// sub elements. Multiple replacement functions may be registered.
///
/// This base utility is uncached. Users can choose between two cached versions
/// of this replacer:
/// * For non-cyclic replacer logic, use `AttrTypeReplacer`.
/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
///
/// Concrete implementations implement the following `replace` entry functions:
/// * Attribute replace(Attribute attr);
/// * Type replace(Type type);
template <typename Concrete>
class AttrTypeReplacerBase {
public:
//===--------------------------------------------------------------------===//
// Application
Expand All @@ -139,12 +152,6 @@ class AttrTypeReplacer {
bool replaceLocs = false,
bool replaceTypes = false);

/// Replace the given attribute/type, and recursively replace any sub
/// elements. Returns either the new attribute/type, or nullptr in the case of
/// failure.
Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -206,21 +213,114 @@ class AttrTypeReplacer {
});
}

private:
/// Internal implementation of the `replace` methods above.
template <typename T, typename ReplaceFns>
T replaceImpl(T element, ReplaceFns &replaceFns);

/// Replace the sub elements of the given interface.
template <typename T>
T replaceSubElements(T interface);
protected:
/// Invokes the registered replacement functions from most recently registered
/// to least recently registered until a successful replacement is returned.
/// Unless skipping is requested, invokes `replace` on sub-elements of the
/// current attr/type.
Attribute replaceBase(Attribute attr);
Type replaceBase(Type type);

private:
/// The set of replacement functions that map sub elements.
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
std::vector<ReplaceFn<Type>> typeReplacementFns;
};

} // namespace detail

/// This is an attribute/type replacer that is naively cached. It is best used
/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
/// re-occurrence of an in-progress element will be skipped.
class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
public:
Attribute replace(Attribute attr);
Type replace(Type type);

private:
/// Shared concrete implementation of the public `replace` functions. Invokes
/// `replaceBase` with caching.
template <typename T>
T cachedReplaceImpl(T element);

// Stores the opaque pointer of an attribute or type.
DenseMap<const void *, const void *> cache;
};

/// This is an attribute/type replacer that supports custom handling of cycles
/// in the replacer logic. In addition to registering replacer functions, it
/// allows registering cycle-breaking functions in the same style.
class CyclicAttrTypeReplacer
: public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
public:
CyclicAttrTypeReplacer();

/// The set of cached mappings for attributes/types.
DenseMap<const void *, const void *> attrTypeMap;
//===--------------------------------------------------------------------===//
// Application
//===--------------------------------------------------------------------===//

Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//

/// A cycle-breaking function. This is invoked if the same element is asked to
/// be replaced again when the first instance of it is still being replaced.
/// This function must not perform any more recursive `replace` calls.
/// If it is able to break the cycle, it should return a replacement result.
/// Otherwise, it can return std::nullopt to defer cycle breaking to the next
/// repeated element. However, the user must guarantee that, in any possible
/// cycle, there always exists at least one element that can break the cycle.
template <typename T>
using CycleBreakerFn = std::function<std::optional<T>(T)>;

/// Register a cycle-breaking function.
/// When breaking cycles, the mostly recently added cycle-breaking functions
/// will be invoked first.
void addCycleBreaker(CycleBreakerFn<Attribute> fn);
void addCycleBreaker(CycleBreakerFn<Type> fn);

/// Register a cycle-breaking function that doesn't match the default
/// signature.
template <typename FnT,
typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<0>,
typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
Attribute, Type>>
std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
addCycleBreaker([callback = std::forward<FnT>(callback)](
BaseT base) -> std::optional<BaseT> {
if (auto derived = dyn_cast<T>(base))
return callback(derived);
return std::nullopt;
});
}

private:
/// Invokes the registered cycle-breaker functions from most recently
/// registered to least recently registered until a successful result is
/// returned.
std::optional<const void *> breakCycleImpl(void *element);

/// Shared concrete implementation of the public `replace` functions.
template <typename T>
T cachedReplaceImpl(T element);

/// The set of registered cycle-breaker functions.
std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;

/// A cache of previously-replaced attr/types.
/// The key of the cache is the opaque value of an AttrOrType. Using
/// AttrOrType allows distinguishing between the two types when invoking
/// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
/// of directly using `AttrOrType` to instantiate the cache.
/// The value of the cache is just the opaque value of the attr/type itself
/// (not the PointerUnion).
using AttrOrType = PointerUnion<Attribute, Type>;
CyclicReplacerCache<void *, const void *> cache;
};

//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading