Skip to content

Commit

Permalink
constant-folding-free printing
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 committed Mar 6, 2023
1 parent b94c208 commit 1fcf769
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 23 deletions.
34 changes: 18 additions & 16 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Div>("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));
if (IsNumber(a) && IsNumber(b)) {
PrimExpr ret = tvm::div(node->a, node->b);
if (!ret->IsInstance<tir::DivNode>()) {
return TIR(d, "Div")->Call({a, b});
}
if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) &&
Expand All @@ -312,31 +313,32 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return OperationDoc(OperationDocNode::Kind::kDiv, {a, b});
});

#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \
#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \
.set_dispatch<tir::NodeType>("", \
[](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b")); \
if (IsNumber(a) && IsNumber(b)) { \
PrimExpr ret = tvm::NodeFunc(node->a, node->b); \
if (!ret->IsInstance<tir::NodeObj>()) { \
return TIR(d, OpString)->Call({a, b}); \
} \
return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \
});

TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv, "FloorDiv", kFloorDiv);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod, "FloorMod", kMod);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, greater, "GT", kGt);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE", kGtE);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And", kAnd);
TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr);

TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod");
TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min");
Expand Down
39 changes: 32 additions & 7 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,13 +501,12 @@ def test_cast():


def test_binary_arith():
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
for op, sign in [
(tir.Add, "+"),
(tir.Sub, "-"),
(tir.Mul, "*"),
(tir.Div, "/"),
(tir.Mod, "truncmod"),
(tir.FloorDiv, "//"),
(tir.FloorMod, "%"),
Expand All @@ -521,21 +520,47 @@ def test_binary_arith():
obj = op(a, b)
if sign.isalpha():
expected = """
a = T.float32()
b = T.float32()
a = T.int32()
b = T.int32()
T.{}(a, b)""".format(
sign
)
else:
expected = """
a = T.float32()
b = T.float32()
a = T.int32()
b = T.int32()
a {} b""".format(
sign
)
_assert_print(obj, expected)


def test_binary_arith_const():
a = tir.IntImm("int64", 3)
b = tir.IntImm("int64", 4)
for op, name in [
(tir.Add, "Add"),
(tir.Sub, "Sub"),
(tir.Mul, "Mul"),
(tir.Div, "Div"),
(tir.Mod, "truncmod"),
(tir.FloorDiv, "FloorDiv"),
(tir.FloorMod, "FloorMod"),
(tir.LT, "LT"),
(tir.LE, "LE"),
(tir.EQ, "EQ"),
(tir.NE, "NE"),
(tir.GT, "GT"),
(tir.GE, "GE"),
]:
obj = op(a, b)
expected = """
T.{}({}, {})""".format(
name, str(a), str(b)
)
_assert_print(obj, expected)


def test_int_div():
a = tir.Var("a", "int32")
b = tir.Var("b", "int32")
Expand Down

0 comments on commit 1fcf769

Please sign in to comment.