Skip to content

Commit

Permalink
refactor attrtype replacers & add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zyx-billy committed Jul 9, 2024
1 parent bcc3eda commit 0d38de6
Show file tree
Hide file tree
Showing 4 changed files with 469 additions and 51 deletions.
142 changes: 121 additions & 21 deletions mlir/include/mlir/IR/AttrTypeSubElements.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/CyclicReplacerCache.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include <optional>
Expand Down Expand Up @@ -116,9 +117,21 @@ class AttrTypeWalker {
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//

/// This class provides a utility for replacing attributes/types, and their sub
/// elements. Multiple replacement functions may be registered.
class AttrTypeReplacer {
namespace detail {

/// This class provides a base utility for replacing attributes/types, and their
/// sub elements. Multiple replacement functions may be registered.
///
/// This base utility is uncached. Users can choose between two cached versions
/// of this replacer:
/// * For non-cyclic replacer logic, use `AttrTypeReplacer`.
/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
///
/// Concrete implementations implement the following `replace` entry functions:
/// * Attribute replace(Attribute attr);
/// * Type replace(Type type);
template <typename Concrete>
class AttrTypeReplacerBase {
public:
//===--------------------------------------------------------------------===//
// Application
Expand All @@ -139,12 +152,6 @@ class AttrTypeReplacer {
bool replaceLocs = false,
bool replaceTypes = false);

/// Replace the given attribute/type, and recursively replace any sub
/// elements. Returns either the new attribute/type, or nullptr in the case of
/// failure.
Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -206,29 +213,122 @@ class AttrTypeReplacer {
});
}

private:
/// Internal implementation of the `replace` methods above.
template <typename T, typename ReplaceFns>
T replaceImpl(T element, ReplaceFns &replaceFns);

/// Replace the sub elements of the given interface.
template <typename T>
T replaceSubElements(T interface);
protected:
/// Invokes the registered replacement functions from most recently registered
/// to least recently registered until a successful replacement is returned.
/// Unless skipping is requested, invokes `replace` on sub-elements of the
/// current attr/type.
Attribute replaceBase(Attribute attr);
Type replaceBase(Type type);

private:
/// The set of replacement functions that map sub elements.
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
std::vector<ReplaceFn<Type>> typeReplacementFns;
};

} // namespace detail

/// This is an attribute/type replacer that is naively cached. It is best used
/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
/// re-occurrence of an in-progress element will be skipped.
class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
public:
Attribute replace(Attribute attr);
Type replace(Type type);

private:
/// Shared concrete implementation of the public `replace` functions. Invokes
/// `replaceBase` with caching.
template <typename T>
T cachedReplaceImpl(T element);

// Stores the opaque pointer of an attribute or type.
DenseMap<const void *, const void *> cache;
};

/// This is an attribute/type replacer that supports custom handling of cycles
/// in the replacer logic. In addition to registering replacer functions, it
/// allows registering cycle-breaking functions in the same style.
class CyclicAttrTypeReplacer
: public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
public:
CyclicAttrTypeReplacer();

/// The set of cached mappings for attributes/types.
DenseMap<const void *, const void *> attrTypeMap;
//===--------------------------------------------------------------------===//
// Application
//===--------------------------------------------------------------------===//

Attribute replace(Attribute attr);
Type replace(Type type);

//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//

/// A cycle-breaking function. This is invoked if the same element is asked to
/// be replaced again when the first instance of it is still being replaced.
/// This function must not perform any more recursive `replace` calls.
/// If it is able to break the cycle, it should return a replacement result.
/// Otherwise, it can return std::nullopt to defer cycle breaking to the next
/// repeated element. However, the user must guarantee that, in any possible
/// cycle, there always exists at least one element that can break the cycle.
template <typename T>
using CycleBreakerFn = std::function<std::optional<T>(T)>;

/// Register a cycle-breaking function.
/// When breaking cycles, the mostly recently added cycle-breaking functions
/// will be invoked first.
void addCycleBreaker(CycleBreakerFn<Attribute> fn);
void addCycleBreaker(CycleBreakerFn<Type> fn);

/// Register a cycle-breaking function that doesn't match the default
/// signature.
template <typename FnT,
typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<0>,
typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
Attribute, Type>>
std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
addCycleBreaker([callback = std::forward<FnT>(callback)](
BaseT base) -> std::optional<BaseT> {
if (auto derived = dyn_cast<T>(base))
return callback(derived);
return std::nullopt;
});
}

private:
/// Invokes the registered cycle-breaker functions from most recently
/// registered to least recently registered until a successful result is
/// returned.
std::optional<const void *> breakCycleImpl(void *element);

/// Shared concrete implementation of the public `replace` functions.
template <typename T>
T cachedReplaceImpl(T element);

/// The set of registered cycle-breaker functions.
std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;

/// A cache of previously-replaced attr/types.
/// The key of the cache is the opaque value of an AttrOrType. Using
/// AttrOrType allows distinguishing between the two types when invoking
/// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
/// of directly using `AttrOrType` to instantiate the cache.
/// The value of the cache is just the opaque value of the attr/type itself
/// (not the PointerUnion).
using AttrOrType = PointerUnion<Attribute, Type>;
CyclicReplacerCache<void *, const void *> cache;
};

//===----------------------------------------------------------------------===//
/// AttrTypeSubElementHandler
//===----------------------------------------------------------------------===//

/// This class is used by AttrTypeSubElementHandler instances to walking sub
/// attributes and types.
/// This class is used by AttrTypeSubElementHandler instances to walking
/// sub attributes and types.
class AttrTypeImmediateSubElementWalker {
public:
AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn,
Expand Down
Loading

0 comments on commit 0d38de6

Please sign in to comment.