From 29a1b3c969395875940422edf819f80c67f300da Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 6 Mar 2023 14:46:51 -0800 Subject: [PATCH 1/4] [Fix][TVMScript] TVMScript BinOP printing refactor This PR fixes the output for `T.Div(int, int)`. It will print `T.Div(int, int)`, instead of `int / int`, to avoid the integer division ambiguity in parser. And this PR refactors the logic of binary operators printing in TVMScript. The updated TVMScript printer will print the binary operator to avoid constant folding when parsing back. --- python/tvm/script/parser/tir/operation.py | 12 ++--- src/script/printer/tir/expr.cc | 47 +++++++++++------ .../python/unittest/test_inject_ptx_ldg32.py | 2 +- .../unittest/test_meta_schedule_space_cuda.py | 6 +-- ...est_tir_transform_inject_virtual_thread.py | 8 +-- .../unittest/test_tvmscript_printer_tir.py | 52 ++++++++++++++++--- 6 files changed, 90 insertions(+), 37 deletions(-) diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index f0c04f47cdf6..ed8f07a06369 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,12 +46,12 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, tir.Add) - r(doc.Sub, i, tir.Sub) - r(doc.Mult, i, tir.Mul) - r(doc.Div, i, tir.Div) - r(doc.FloorDiv, i, tir.FloorDiv) - r(doc.Mod, i, tir.FloorMod) + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) r(doc.LShift, i, lambda a, b: a << b) r(doc.RShift, i, lambda a, b: a >> b) r(doc.BitOr, i, lambda a, b: a | b) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 02ec269b0e73..003d56800008 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -297,32 +297,47 @@ bool IsNumber(const ExprDoc& e) { return false; } -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); + PrimExpr ret = tvm::div(node->a, node->b); + if (!ret->IsInstance()) { + return TIR(d, "Div")->Call({a, b}); + } + if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && + (node->b->dtype.is_int() || node->b->dtype.is_uint())) { + return TIR(d, "Div")->Call({a, b}); + } + return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); + }); + +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - if (IsNumber(a) && IsNumber(b)) { \ + PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ + if (!ret->IsInstance()) { \ return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv, "FloorDiv", kFloorDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod, "FloorMod", kMod); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, greater, "GT", kGt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE", kGtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And", kAnd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr); TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod"); TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min"); diff --git a/tests/python/unittest/test_inject_ptx_ldg32.py b/tests/python/unittest/test_inject_ptx_ldg32.py index 81c6e89ad921..8e8547c572d0 100644 --- a/tests/python/unittest/test_inject_ptx_ldg32.py +++ b/tests/python/unittest/test_inject_ptx_ldg32.py @@ -32,7 +32,7 @@ def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> No with T.block(): T.reads(A[0:16]) T.writes(A_local[0:32]) - A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx / 2], T.float32(0), dtype="float32") + A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") B[tx] = A_local[tx] + 1.0 diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index bc674064d1d6..ef662ed5b1e7 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -315,7 +315,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) - v2 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0) + v2 = T.axis.spatial(18, T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0)) v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) @@ -493,9 +493,9 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, for ax0_ax1_ax2_ax3_fused in T.serial(217): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0) + v1 = T.axis.spatial(230, T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0)) v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) - v3 = T.axis.spatial(3, i6_0 + 0) + v3 = T.axis.spatial(3, T.Add(i6_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) T.block_attr({"meta_schedule.cooperative_fetch":2}) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index d32714938424..beb20fd43ba6 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -182,10 +182,10 @@ def before_func(): def expected_func(): B_data = T.allocate([4], "int32x4", "shared") B = T.Buffer([4], "int32x4", data=B_data, scope="shared") - B[T.Mul(0, 4) / 4] = T.broadcast(0, 4) - B[T.Mul(1, 4) / 4] = T.broadcast(1, 4) - B[T.Mul(2, 4) / 4] = T.broadcast(2, 4) - B[T.Mul(3, 4) / 4] = T.broadcast(3, 4) + B[T.Div(T.Mul(0, 4), 4)] = T.broadcast(0, 4) + B[T.Div(T.Mul(1, 4), 4)] = T.broadcast(1, 4) + B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4) + B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index e74f69dcae8b..87ec98e9a266 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -501,13 +501,12 @@ def test_cast(): def test_binary_arith(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") for op, sign in [ (tir.Add, "+"), (tir.Sub, "-"), (tir.Mul, "*"), - (tir.Div, "/"), (tir.Mod, "truncmod"), (tir.FloorDiv, "//"), (tir.FloorMod, "%"), @@ -521,21 +520,60 @@ def test_binary_arith(): obj = op(a, b) if sign.isalpha(): expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() T.{}(a, b)""".format( sign ) else: expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() a {} b""".format( sign ) _assert_print(obj, expected) +def test_binary_arith_const(): + a = tir.IntImm("int64", 3) + b = tir.IntImm("int64", 4) + for op, name in [ + (tir.Add, "Add"), + (tir.Sub, "Sub"), + (tir.Mul, "Mul"), + (tir.Div, "Div"), + (tir.Mod, "truncmod"), + (tir.FloorDiv, "FloorDiv"), + (tir.FloorMod, "FloorMod"), + (tir.LT, "LT"), + (tir.LE, "LE"), + (tir.EQ, "EQ"), + (tir.NE, "NE"), + (tir.GT, "GT"), + (tir.GE, "GE"), + ]: + obj = op(a, b) + expected = """ +T.{}({}, {})""".format( + name, str(a), str(b) + ) + _assert_print(obj, expected) + + +def test_int_div(): + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") + _assert_print( + tir.Div(a, b), + """ +a = T.int32() +b = T.int32() +T.Div(a, b) +""", + ) + + def test_logical(): a = tir.Var("a", "bool") b = tir.Var("b", "bool") From 156c2565dc5ca22431c672714c449947bf79911c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 6 Mar 2023 23:14:15 -0800 Subject: [PATCH 2/4] fix unittest --- .../python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 1cab2554e88f..97ee53f4e409 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -851,7 +851,7 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): with T.block("C_reindex_shared"): - v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2 + 0) + v0 = T.axis.spatial(4, T.Add(ax0_0_0_ax1_0_0_fused // 2, 0)) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) v2 = T.axis.spatial(2, ax2) v3 = T.axis.spatial(1, 0) From c42c26a9c4a12e87b0c00e7ad9cec4894c531d31 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 7 Mar 2023 13:07:48 -0800 Subject: [PATCH 3/4] fix unittest --- .../test_meta_schedule_feature_extractor_per_store_feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py index 88947962d69d..c62ac788d74b 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -70,7 +70,7 @@ def main(placeholder: T.Buffer((1, 16, 7, 7, 32), "float32"), placeholder_1: T.B ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512) T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") # fmt: on From c3cf24d93f6f5b19fb19fa5dd2b53a0bb571f0f7 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 7 Mar 2023 13:40:57 -0800 Subject: [PATCH 4/4] remove overloaded op --- python/tvm/script/parser/tir/operation.py | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index ed8f07a06369..3e120339a6e4 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,17 +46,17 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, lambda a, b: a + b) - r(doc.Sub, i, lambda a, b: a - b) - r(doc.Mult, i, lambda a, b: a * b) - r(doc.Div, i, lambda a, b: a / b) - r(doc.FloorDiv, i, lambda a, b: a // b) - r(doc.Mod, i, lambda a, b: a % b) - r(doc.LShift, i, lambda a, b: a << b) - r(doc.RShift, i, lambda a, b: a >> b) - r(doc.BitOr, i, lambda a, b: a | b) - r(doc.BitXor, i, lambda a, b: a ^ b) - r(doc.BitAnd, i, lambda a, b: a & b) + # doc.Add <-- is overloaded + # doc.Sub <-- is overloaded + # doc.Mult <-- is overloaded + # doc.Div <-- is overloaded + # doc.FloorDiv <-- is overloaded + # doc.Mod <-- is overloaded + # doc.LShift <-- is overloaded + # doc.RShift <-- is overloaded + # doc.BitOr <-- is overloaded + # doc.BitXor <-- is overloaded + # doc.BitAnd <-- is overloaded # doc.MatMult <-- not implemented # doc.Pow <-- not implemented # Case 2. cmpop @@ -75,10 +75,10 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name r(doc.Or, i, _or) for i in [0]: # Case 4. unaryop - r(doc.Invert, i, lambda a: ~a) + # doc.Invert <-- is overloaded r(doc.Not, i, tir.Not) - r(doc.UAdd, i, lambda a: +a) - r(doc.USub, i, lambda a: -a) + # doc.UAdd <-- is overloaded + # doc.USub <-- is overloaded _register_expr_op(tir.PrimExpr)