diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index f0c04f47cdf6..3e120339a6e4 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,17 +46,17 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, tir.Add) - r(doc.Sub, i, tir.Sub) - r(doc.Mult, i, tir.Mul) - r(doc.Div, i, tir.Div) - r(doc.FloorDiv, i, tir.FloorDiv) - r(doc.Mod, i, tir.FloorMod) - r(doc.LShift, i, lambda a, b: a << b) - r(doc.RShift, i, lambda a, b: a >> b) - r(doc.BitOr, i, lambda a, b: a | b) - r(doc.BitXor, i, lambda a, b: a ^ b) - r(doc.BitAnd, i, lambda a, b: a & b) + # doc.Add <-- is overloaded + # doc.Sub <-- is overloaded + # doc.Mult <-- is overloaded + # doc.Div <-- is overloaded + # doc.FloorDiv <-- is overloaded + # doc.Mod <-- is overloaded + # doc.LShift <-- is overloaded + # doc.RShift <-- is overloaded + # doc.BitOr <-- is overloaded + # doc.BitXor <-- is overloaded + # doc.BitAnd <-- is overloaded # doc.MatMult <-- not implemented # doc.Pow <-- not implemented # Case 2. cmpop @@ -75,10 +75,10 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name r(doc.Or, i, _or) for i in [0]: # Case 4. unaryop - r(doc.Invert, i, lambda a: ~a) + # doc.Invert <-- is overloaded r(doc.Not, i, tir.Not) - r(doc.UAdd, i, lambda a: +a) - r(doc.USub, i, lambda a: -a) + # doc.UAdd <-- is overloaded + # doc.USub <-- is overloaded _register_expr_op(tir.PrimExpr) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 02ec269b0e73..003d56800008 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -297,32 +297,47 @@ bool IsNumber(const ExprDoc& e) { return false; } -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); + PrimExpr ret = tvm::div(node->a, node->b); + if (!ret->IsInstance()) { + return TIR(d, "Div")->Call({a, b}); + } + if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && + (node->b->dtype.is_int() || node->b->dtype.is_uint())) { + return TIR(d, "Div")->Call({a, b}); + } + return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); + }); + +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - if (IsNumber(a) && IsNumber(b)) { \ + PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ + if (!ret->IsInstance()) { \ return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv, "FloorDiv", kFloorDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod, "FloorMod", kMod); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, greater, "GT", kGt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE", kGtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And", kAnd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr); TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod"); TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min"); diff --git a/tests/python/unittest/test_inject_ptx_ldg32.py b/tests/python/unittest/test_inject_ptx_ldg32.py index 81c6e89ad921..8e8547c572d0 100644 --- a/tests/python/unittest/test_inject_ptx_ldg32.py +++ b/tests/python/unittest/test_inject_ptx_ldg32.py @@ -32,7 +32,7 @@ def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> No with T.block(): T.reads(A[0:16]) T.writes(A_local[0:32]) - A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx / 2], T.float32(0), dtype="float32") + A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") B[tx] = A_local[tx] + 1.0 diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py index 88947962d69d..c62ac788d74b 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -70,7 +70,7 @@ def main(placeholder: T.Buffer((1, 16, 7, 7, 32), "float32"), placeholder_1: T.B ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512) T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") # fmt: on diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 1cab2554e88f..97ee53f4e409 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -851,7 +851,7 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): with T.block("C_reindex_shared"): - v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2 + 0) + v0 = T.axis.spatial(4, T.Add(ax0_0_0_ax1_0_0_fused // 2, 0)) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) v2 = T.axis.spatial(2, ax2) v3 = T.axis.spatial(1, 0) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index bc674064d1d6..ef662ed5b1e7 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -315,7 +315,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) - v2 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0) + v2 = T.axis.spatial(18, T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0)) v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) @@ -493,9 +493,9 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, for ax0_ax1_ax2_ax3_fused in T.serial(217): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0) + v1 = T.axis.spatial(230, T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0)) v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) - v3 = T.axis.spatial(3, i6_0 + 0) + v3 = T.axis.spatial(3, T.Add(i6_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) T.block_attr({"meta_schedule.cooperative_fetch":2}) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index d32714938424..beb20fd43ba6 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -182,10 +182,10 @@ def before_func(): def expected_func(): B_data = T.allocate([4], "int32x4", "shared") B = T.Buffer([4], "int32x4", data=B_data, scope="shared") - B[T.Mul(0, 4) / 4] = T.broadcast(0, 4) - B[T.Mul(1, 4) / 4] = T.broadcast(1, 4) - B[T.Mul(2, 4) / 4] = T.broadcast(2, 4) - B[T.Mul(3, 4) / 4] = T.broadcast(3, 4) + B[T.Div(T.Mul(0, 4), 4)] = T.broadcast(0, 4) + B[T.Div(T.Mul(1, 4), 4)] = T.broadcast(1, 4) + B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4) + B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index e74f69dcae8b..87ec98e9a266 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -501,13 +501,12 @@ def test_cast(): def test_binary_arith(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") for op, sign in [ (tir.Add, "+"), (tir.Sub, "-"), (tir.Mul, "*"), - (tir.Div, "/"), (tir.Mod, "truncmod"), (tir.FloorDiv, "//"), (tir.FloorMod, "%"), @@ -521,21 +520,60 @@ def test_binary_arith(): obj = op(a, b) if sign.isalpha(): expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() T.{}(a, b)""".format( sign ) else: expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() a {} b""".format( sign ) _assert_print(obj, expected) +def test_binary_arith_const(): + a = tir.IntImm("int64", 3) + b = tir.IntImm("int64", 4) + for op, name in [ + (tir.Add, "Add"), + (tir.Sub, "Sub"), + (tir.Mul, "Mul"), + (tir.Div, "Div"), + (tir.Mod, "truncmod"), + (tir.FloorDiv, "FloorDiv"), + (tir.FloorMod, "FloorMod"), + (tir.LT, "LT"), + (tir.LE, "LE"), + (tir.EQ, "EQ"), + (tir.NE, "NE"), + (tir.GT, "GT"), + (tir.GE, "GE"), + ]: + obj = op(a, b) + expected = """ +T.{}({}, {})""".format( + name, str(a), str(b) + ) + _assert_print(obj, expected) + + +def test_int_div(): + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") + _assert_print( + tir.Div(a, b), + """ +a = T.int32() +b = T.int32() +T.Div(a, b) +""", + ) + + def test_logical(): a = tir.Var("a", "bool") b = tir.Var("b", "bool")