From 1fcf769dda345ed6d6182582a2bc62ad816aed50 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 6 Mar 2023 14:46:51 -0800 Subject: [PATCH] constant-folding-free printing --- src/script/printer/tir/expr.cc | 34 ++++++++-------- .../unittest/test_tvmscript_printer_tir.py | 39 +++++++++++++++---- 2 files changed, 50 insertions(+), 23 deletions(-) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 49325c8609183..b6ee03c7a5d83 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -302,7 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); ExprDoc b = d->AsDoc(node->b, p->Attr("b")); - if (IsNumber(a) && IsNumber(b)) { + PrimExpr ret = tvm::div(node->a, node->b); + if (!ret->IsInstance()) { return TIR(d, "Div")->Call({a, b}); } if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && @@ -312,31 +313,32 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); }); -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - if (IsNumber(a) && IsNumber(b)) { \ + PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ + if (!ret->IsInstance()) { \ return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv, "FloorDiv", kFloorDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod, "FloorMod", kMod); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, greater, "GT", kGt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE", kGtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And", kAnd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr); TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod"); TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min"); diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 8036fc269a7d5..bcb14f6b4c895 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -501,13 +501,12 @@ def test_cast(): def test_binary_arith(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") for op, sign in [ (tir.Add, "+"), (tir.Sub, "-"), (tir.Mul, "*"), - (tir.Div, "/"), (tir.Mod, "truncmod"), (tir.FloorDiv, "//"), (tir.FloorMod, "%"), @@ -521,21 +520,47 @@ def test_binary_arith(): obj = op(a, b) if sign.isalpha(): expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() T.{}(a, b)""".format( sign ) else: expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() a {} b""".format( sign ) _assert_print(obj, expected) +def test_binary_arith_const(): + a = tir.IntImm("int64", 3) + b = tir.IntImm("int64", 4) + for op, name in [ + (tir.Add, "Add"), + (tir.Sub, "Sub"), + (tir.Mul, "Mul"), + (tir.Div, "Div"), + (tir.Mod, "truncmod"), + (tir.FloorDiv, "FloorDiv"), + (tir.FloorMod, "FloorMod"), + (tir.LT, "LT"), + (tir.LE, "LE"), + (tir.EQ, "EQ"), + (tir.NE, "NE"), + (tir.GT, "GT"), + (tir.GE, "GE"), + ]: + obj = op(a, b) + expected = """ +T.{}({}, {})""".format( + name, str(a), str(b) + ) + _assert_print(obj, expected) + + def test_int_div(): a = tir.Var("a", "int32") b = tir.Var("b", "int32")