diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 16249377a27d..bbc8e9382ed0 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -68,6 +68,34 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, * \return The rewritten or the input function, depending on the pattern matching result. */ TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f); + +/** + * \brief Rewrite a function with the given pattern and the rewriter function. + * + * Pattern match and replace at an expression level. This level of + * granularity does not allow simultaneous replacement cannot be + * performed. In addition, removal of bindings cannot be performed + * explicitly, and is only done implicitly through RemoveAllUnused. + * See also `RewriteBindings`, which performs replacement on a + * block-level, and does not have these restrictions. + * + * \param pattern The pattern to be replaced + * + * \param rewriter The function to be called on a successful pattern + * matching. Given the matched expression and a map of sub-matches, + * it should return the replacement expression. If the expression + * doesn't require updating (e.g. replacement required checks beyond + * those expressed in the pattern), it should return the expression + * unmodified. + * + * \param func The function to rewrite + * + * \return The updated function, if any updates were applied. + */ +TVM_DLL Function RewriteCall(const DFPattern& pattern, + TypedPackedFunc)> rewriter, + Function func); + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 933429cb9b24..b634b315d98e 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -727,7 +727,14 @@ class WildcardPatternNode : public DFPatternNode { */ class WildcardPattern : public DFPattern { public: - TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); + WildcardPattern(); + + // Declaring WildcardPattern declared as non-nullable avoids the + // default zero-parameter constructor for ObjectRef with `data_ = + // nullptr`. This allows a zero-parameter constructor to be + // declared here, to create a valid wildcard instance. + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); }; /*! diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 7fb67d9376f5..c2515067edcf 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -78,6 +78,8 @@ void DFPatternMatcher::ClearMap(size_t watermark) { } bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { + CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; + auto expr = TryGetValOfVar(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { ICHECK_EQ(memo_[pattern].size(), 1); @@ -1106,12 +1108,14 @@ Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Functio return PatternRewriter::Run(ctx, rewriter, f); } -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call") - .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) { - return PatternRewriter::Run(pat, rewriter, f); - }); - TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); +Function RewriteCall(const DFPattern& pat, + TypedPackedFunc)> rewriter, Function f) { + return PatternRewriter::Run(pat, rewriter, f); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index faa890a12c39..1286a32e4cb8 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -241,10 +241,8 @@ RELAX_PATTERN_PRINTER_DEF(NotPatternNode, [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); TVM_REGISTER_NODE_TYPE(WildcardPatternNode); -TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { - auto w = WildcardPattern(make_object()); - return w; -}); +WildcardPattern::WildcardPattern() { data_ = make_object(); } +TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { return WildcardPattern(); }); RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); TVM_REGISTER_NODE_TYPE(TypePatternNode);