Skip to content

Commit

Permalink
[Unity] QoL improvements for Dataflow matching
Browse files Browse the repository at this point in the history
- Update the zero-parameter `WildcardPattern` constructor to produce a
  valid instance.  Previously, the zero-parameter constructor produced
  a null instance of `WildcardPattern`, which resulted in an error
  when used.  The `WildcardPattern` was expected to be constructed
  through the `Wildcard` function instead.  Since all other
  `DFPattern` child classes could be constructed explicitly, this
  could lead to unexpected outcomes.

- Check for `pattern.defined()` when performing a pattern-match.  If
  a null instance of a pattern is provided, this gives an error
  message with more context than the one raised by `DFPatternFunctor`.

- Expose `RewriteCall` for use in C++.  Previously, it had only been
  exposed through the FFI registry, and had no declaration in a header
  file.
  • Loading branch information
Lunderberg committed Dec 29, 2023
1 parent ca60c63 commit 703aa15
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
28 changes: 28 additions & 0 deletions include/tvm/relax/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,34 @@ TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
* \return The rewritten or the input function, depending on the pattern matching result.
*/
TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f);

/**
* \brief Rewrite a function with the given pattern and the rewriter function.
*
* Pattern match and replace at an expression level. This level of
* granularity does not allow simultaneous replacement cannot be
* performed. In addition, removal of bindings cannot be performed
* explicitly, and is only done implicitly through RemoveAllUnused.
* See also `RewriteBindings`, which performs replacement on a
* block-level, and does not have these restrictions.
*
* \param pattern The pattern to be replaced
*
* \param rewriter The function to be called on a successful pattern
* matching. Given the matched expression and a map of sub-matches,
* it should return the replacement expression. If the expression
* doesn't require updating (e.g. replacement required checks beyond
* those expressed in the pattern), it should return the expression
* unmodified.
*
* \param func The function to rewrite
*
* \return The updated function, if any updates were applied.
*/
TVM_DLL Function RewriteCall(const DFPattern& pattern,
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter,
Function func);

} // namespace relax
} // namespace tvm

Expand Down
9 changes: 8 additions & 1 deletion include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,14 @@ class WildcardPatternNode : public DFPatternNode {
*/
class WildcardPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
WildcardPattern();

// Declaring WildcardPattern declared as non-nullable avoids the
// default zero-parameter constructor for ObjectRef with `data_ =
// nullptr`. This allows a zero-parameter constructor to be
// declared here, to create a valid wildcard instance.

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
};

/*!
Expand Down
14 changes: 9 additions & 5 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ void DFPatternMatcher::ClearMap(size_t watermark) {
}

bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) {
CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0;

auto expr = TryGetValOfVar(expr0, var2val_);
if (memoize_ && memo_.count(pattern)) {
ICHECK_EQ(memo_[pattern].size(), 1);
Expand Down Expand Up @@ -1106,12 +1108,14 @@ Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Functio
return PatternRewriter::Run(ctx, rewriter, f);
}

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
.set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
});

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);

Function RewriteCall(const DFPattern& pat,
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
}

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);

} // namespace relax
} // namespace tvm
6 changes: 2 additions & 4 deletions src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,8 @@ RELAX_PATTERN_PRINTER_DEF(NotPatternNode,
[](auto p, auto node) { p->stream << "!(" << node->reject << ")"; });

TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() {
auto w = WildcardPattern(make_object<WildcardPatternNode>());
return w;
});
WildcardPattern::WildcardPattern() { data_ = make_object<WildcardPatternNode>(); }
TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { return WildcardPattern(); });
RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; });

TVM_REGISTER_NODE_TYPE(TypePatternNode);
Expand Down

0 comments on commit 703aa15

Please sign in to comment.