Skip to content

Commit

Permalink
[PatternLang] Add If pattern (apache#7282)
Browse files Browse the repository at this point in the history
* Add if pattern

commit 1ee052fd494a5bdd881c242c3ea0c95cf2a613e5
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 22:19:17 2020 +0900

    add comment

commit c846a6999e9c9e48fbc019780e705a990f46cb22
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 21:14:20 2020 +0900

    max_out_size rewrite added to the test

commit 2c7c7fbd0e6563aba694e7fb6baa7bda8e4fadca
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 20:57:55 2020 +0900

    max_out_size rewrite working

commit 319e930acb8162c1ec4a5d4fb71d134580a68f13
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 20:43:16 2020 +0900

    refactor dyn strided slice pattern

commit fb6917b703440748800bde624bc20efaf5798b8a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 11:21:33 2020 +0900

    update NMS pattern following frontend change

commit 255a98f1da8f300d4fe417cce3587c0d71e38ed3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Dec 24 05:19:31 2020 +0900

    add some comment to explain the pattern

commit 52cea1cc2bff533ca60acfc2416477fc8b058428
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 08:35:14 2020 +0900

    revert tutorial change

commit d3e0e0d7e2427c40067d6ad2680ec5b3f0076223
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 08:02:29 2020 +0900

    test fixed by setting force_surpress=False

commit 2fa1a574f932001be2d8f601338a342dab92f79c
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 07:22:32 2020 +0900

    fixed coord_start

commit 6ba88f27dec1bdb0b0ba746c268591a59264088e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 06:50:46 2020 +0900

    add doc

commit 8d386b6a1c92ce4fe3349ff20e320199a1b5b310
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 05:27:26 2020 +0900

    updated tutorial

commit 3206b49ecfdd874e0ff8feb0fa586c4c4282f705
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 05:04:44 2020 +0900

    update object detection test to add rewrite

commit 74bebb2f4376aeb67d8c4aad395f9f2661fe6b3e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Wed Dec 23 05:02:15 2020 +0900

    add a pattern to rewrite nms to batched nms

commit f410e6dde0ed949b90312c5a7ddbb6c234f9acc1
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 22:20:16 2020 +0900

    add comment

commit f1e078b0724bd22e7be0a812055e1c7c650d94da
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 26 19:54:22 2020 +0900

    Add if pattern

* add doc

* add test

* doc formatting

* cpplint fix
  • Loading branch information
masahi committed Jan 18, 2021
1 parent 6b83c13 commit fa9b899
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ The next example is matching function nodes with a specific attribute:
f = relay.Function([x, y], x + y).with_attr("Composite", "add")
assert pattern.match(f)
A Relay ``If`` expression can be matched if all of its condition, true branch and false branch
are matched:

.. code-block:: python
def test_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)
x = relay.var("x")
y = relay.var("y")
cond = x < y
assert pat.match(relay.expr.If(cond, x, y))
Matching Diamonds and Post-Dominator Graphs
*******************************************
Expand Down Expand Up @@ -294,6 +309,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_op(op_name)
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ class TupleGetItemPatternNode : public DFPatternNode {
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
};

class IfPatternNode : public DFPatternNode {
public:
DFPattern cond, true_branch, false_branch;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode);
};

class IfPattern : public DFPattern {
public:
TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause);
TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode);
};

class TupleGetItemPattern : public DFPattern {
public:
TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -116,6 +117,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
Expand Down Expand Up @@ -144,6 +146,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,29 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) ->
return TupleGetItemPattern(tuple_value, index)


def is_if(cond, true_branch, false_branch):
"""
Syntatic sugar for creating an IfPattern.
Parameters
----------
cond: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the condition of If.
true_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the true branch of If.
false_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the false branch of If.
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting pattern.
"""
return IfPattern(cond, true_branch, false_branch)


def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
Expand Down Expand Up @@ -536,6 +559,26 @@ def __init__(
self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)


@register_df_node
class IfPattern(DFPattern):
"""A patern matching a Relay If.
Parameters
----------
cond: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the condition of If.
true_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the true branch of If.
false_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the false branch of If.
"""

def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "DFPattern"):
self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch)


@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
Expand Down
12 changes: 12 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -411,6 +412,17 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr) {
if (const auto* if_node = expr.as<IfNode>()) {
auto cond = if_node->cond;
auto true_branch = if_node->true_branch;
auto false_branch = if_node->false_branch;
return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) &&
VisitDFPattern(op->false_branch, false_branch);
}
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(IfPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern")
.set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
return IfPattern(cond, true_branch, false_branch);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IfPatternNode*>(ref.get());
p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", "
<< node->false_branch << ")";
});

TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
ObjectPtr<TuplePatternNode> n = make_object<TuplePatternNode>();
n->fields = std::move(fields);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
}
}

void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) {
VisitDFPattern(op->cond);
VisitDFPattern(op->true_branch);
VisitDFPattern(op->false_branch);
}

void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
}
}

void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->cond, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->true_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->false_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ def test_AttrPattern():
assert op.attrs["TOpPattern"] == K_ELEMWISE


def test_IfPattern():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

assert isinstance(pat, IfPattern)
assert isinstance(pat.cond, CallPattern)
assert isinstance(pat.true_branch, VarPattern)
assert isinstance(pat.false_branch, VarPattern)


## MATCHER TESTS


Expand Down Expand Up @@ -198,6 +209,30 @@ def test_no_match_func():
assert not func_pattern.match(relay.Function([x, y], x - y))


def test_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")
cond = x < y

assert pat.match(relay.expr.If(cond, x, y))


def test_no_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")

assert not pat.match(relay.expr.If(x > y, x, y))
assert not pat.match(relay.expr.If(x < y, y, x))


def test_match_option():
x = relay.var("x")
w = relay.var("w")
Expand Down Expand Up @@ -1541,3 +1576,6 @@ def test_partition_constant_embedding():
test_partition_option()
test_match_match()
test_partition_constant_embedding()
test_IfPattern()
test_match_if()
test_no_match_if()

0 comments on commit fa9b899

Please sign in to comment.