Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang] Add a relay LetPattern #7332

Merged
merged 3 commits into from
Jan 23, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,24 @@ are matched:

assert pat.match(relay.expr.If(cond, x, y))


A Relay ``Let`` expression can be matched if all of its variable, value, and body
are matched:

.. code-block:: python

def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))

Matching Diamonds and Post-Dominator Graphs
*******************************************

Expand Down Expand Up @@ -310,6 +328,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
| is_let(var, value, body)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
Expand Down Expand Up @@ -367,6 +386,16 @@ Function Pattern

Match a Function with a body and parameters

If Pattern
**********

Match an If with condition, true branch, and false branch

Let Pattern
***********

Match a Let with a variable, value, and body

Applications
============

Expand Down
36 changes: 36 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,42 @@ class FunctionPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};

/*! \brief A binding of a sub-network. */
class LetPatternNode : public DFPatternNode {
public:
/*! \brief The variable we bind to */
DFPattern var;
/*! \brief The value we bind var to */
DFPattern value;
/*! \brief The body of the let binding */
DFPattern body;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
};

/*!
* \brief Let binding that binds a local var
*/
class LetPattern : public DFPattern {
public:
/*!
* \brief The constructor
* \param var The variable that is bound to.
* \param value The value used to bind to the variable.
* \param body The body of the let binding.
*/
TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);

TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
};

/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
Expand Down
11 changes: 7 additions & 4 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -115,9 +116,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
Expand All @@ -143,10 +145,11 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const LetPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch):
return IfPattern(cond, true_branch, false_branch)


def is_let(var, value, body):
"""
Syntatic sugar for creating an IfPattern.
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.

value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.

body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting pattern.
"""
return LetPattern(var, value, body)


def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
Expand Down Expand Up @@ -579,6 +602,27 @@ def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "D
self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch)


@register_df_node
class LetPattern(DFPattern):
"""A patern matching a Relay If.
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.

value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.

body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.

"""

def __init__(self, var: "DFPattern", value: "DFPattern", body: "DFPattern"):
self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body)


@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
Expand Down
11 changes: 10 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -423,6 +424,14 @@ bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& expr) {
if (const auto* let_node = expr.as<LetNode>()) {
return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) &&
VisitDFPattern(op->body, let_node->body);
}
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(LetPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")
.set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
return LetPattern(var, value, body);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const LetPatternNode*>(ref.get());
p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body
<< ")";
});

IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) {
VisitDFPattern(op->false_branch);
}

void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) {
VisitDFPattern(op->var);
VisitDFPattern(op->value);
VisitDFPattern(op->body);
}

void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
VisitDFPattern(op->false_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->var, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->value, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ def test_IfPattern():
assert isinstance(pat.false_branch, VarPattern)


def test_LetPattern():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

assert isinstance(pat, LetPattern)
assert isinstance(pat.var, VarPattern)
assert isinstance(pat.value, CallPattern)
assert isinstance(pat.body, VarPattern)


## MATCHER TESTS


Expand Down Expand Up @@ -233,6 +245,33 @@ def test_no_match_if():
assert not pat.match(relay.expr.If(x < y, y, x))


def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))


def test_no_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")

assert not pat.match(relay.expr.Let(lv, x > y, lv))
assert not pat.match(relay.expr.Let(lv, x < y, lv * x))


def test_match_option():
x = relay.var("x")
w = relay.var("w")
Expand Down