diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index e3dd83743eb7..408c703d54b4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode { /*! \brief Decorators of function. */ Array decorators; /*! \brief The return type of function. */ - ExprDoc return_type{nullptr}; + Optional return_type{NullOpt}; /*! \brief The body of function. */ Array body; @@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc { * \param body The body of function. */ explicit FunctionDoc(IdDoc name, Array args, Array decorators, - ExprDoc return_type, Array body); + Optional return_type, Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode); }; diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 747ffc42f146..0a5fde89758d 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -439,7 +439,7 @@ class FunctionDoc(StmtDoc): name: IdDoc args: Sequence[AssignDoc] decorators: Sequence[ExprDoc] - return_type: ExprDoc + return_type: Optional[ExprDoc] body: Sequence[StmtDoc] def __init__( @@ -447,7 +447,7 @@ def __init__( name: IdDoc, args: List[AssignDoc], decorators: List[ExprDoc], - return_type: ExprDoc, + return_type: Optional[ExprDoc], body: List[StmtDoc], ): self.__init_handle_by_constructor__( diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index bfff0cfad4fe..2334d1fad511 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) { } FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decorators, - ExprDoc return_type, Array body) { + Optional return_type, Array body) { ObjectPtr n = make_object(); n->name = name; n->args = args; @@ -345,7 +345,7 @@ TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) TVM_REGISTER_NODE_TYPE(FunctionDocNode); TVM_REGISTER_GLOBAL("script.printer.FunctionDoc") .set_body_typed([](IdDoc name, Array args, Array decorators, - ExprDoc return_type, Array body) { + Optional return_type, Array body) { return FunctionDoc(name, args, decorators, return_type, body); }); diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index f44577ff80ad..03a661e52919 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -31,6 +31,114 @@ namespace tvm { namespace script { namespace printer { +/*! + * \brief Operator precedence + * + * This is based on + * https://docs.python.org/3/reference/expressions.html#operator-precedence + */ +enum class ExprPrecedence : int32_t { + /*! \brief Unknown precedence */ + kUnkown = 0, + /*! \brief Lambda Expression */ + kLambda = 1, + /*! \brief Conditional Expression */ + kIfThenElse = 2, + /*! \brief Boolean OR */ + kBooleanOr = 3, + /*! \brief Boolean AND */ + kBooleanAnd = 4, + /*! \brief Boolean NOT */ + kBooleanNot = 5, + /*! \brief Comparisons */ + kComparison = 6, + /*! \brief Bitwise OR */ + kBitwiseOr = 7, + /*! \brief Bitwise XOR */ + kBitwiseXor = 8, + /*! \brief Bitwise AND */ + kBitwiseAnd = 9, + /*! \brief Shift Operators */ + kShift = 10, + /*! \brief Addition and subtraction */ + kAdd = 11, + /*! \brief Multiplication, division, floor division, remainder */ + kMult = 12, + /*! \brief Positive negative and bitwise NOT */ + kUnary = 13, + /*! \brief Exponentiation */ + kExp = 14, + /*! \brief Index access, attribute access, call and atom expression */ + kIdentity = 15, +}; + +#define DOC_PRECEDENCE_ENTRY(RefType, Precedence) \ + { RefType::ContainerType::RuntimeTypeIndex(), ExprPrecedence::Precedence } + +ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { + // Key is the value of OperationDocNode::Kind + static const std::vector op_kind_precedence = []() { + using OpKind = OperationDocNode::Kind; + std::map raw_table = { + {OpKind::kUSub, ExprPrecedence::kUnary}, // + {OpKind::kInvert, ExprPrecedence::kUnary}, // + {OpKind::kAdd, ExprPrecedence::kAdd}, // + {OpKind::kSub, ExprPrecedence::kAdd}, // + {OpKind::kMult, ExprPrecedence::kMult}, // + {OpKind::kDiv, ExprPrecedence::kMult}, // + {OpKind::kFloorDiv, ExprPrecedence::kMult}, // + {OpKind::kMod, ExprPrecedence::kMult}, // + {OpKind::kPow, ExprPrecedence::kExp}, // + {OpKind::kLShift, ExprPrecedence::kShift}, // + {OpKind::kRShift, ExprPrecedence::kShift}, // + {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, // + {OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, // + {OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, // + {OpKind::kLt, ExprPrecedence::kComparison}, // + {OpKind::kLtE, ExprPrecedence::kComparison}, // + {OpKind::kEq, ExprPrecedence::kComparison}, // + {OpKind::kNotEq, ExprPrecedence::kComparison}, // + {OpKind::kGt, ExprPrecedence::kComparison}, // + {OpKind::kGtE, ExprPrecedence::kComparison}, // + {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, // + }; + + std::vector table; + table.resize(static_cast(OperationDocNode::Kind::kSpecialEnd) + 1); + + for (const auto& kv : raw_table) { + table[static_cast(kv.first)] = kv.second; + } + + return table; + }(); + + // Key is the type index of Doc + static const std::unordered_map doc_type_precedence = { + DOC_PRECEDENCE_ENTRY(LiteralDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(IdDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(AttrAccessDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(IndexDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(CallDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(LambdaDoc, kLambda), // + DOC_PRECEDENCE_ENTRY(TupleDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(ListDoc, kIdentity), // + DOC_PRECEDENCE_ENTRY(DictDoc, kIdentity), // + }; + + if (const auto* op_doc = doc.as()) { + ExprPrecedence precedence = op_kind_precedence[static_cast(op_doc->kind)]; + ICHECK(precedence != ExprPrecedence::kUnkown) + << "Precedence for operator " << static_cast(op_doc->kind) << " is unknown"; + return precedence; + } else if (doc_type_precedence.find(doc->type_index()) != doc_type_precedence.end()) { + return doc_type_precedence.at(doc->type_index()); + } else { + ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown"; + throw; + } +} + class PythonDocPrinter : public DocPrinter { public: explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {} @@ -98,6 +206,42 @@ class PythonDocPrinter : public DocPrinter { } } + /*! + * \brief Print expression and add parenthesis if needed. + */ + void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence, + bool parenthesis_for_same_precedence = false) { + ExprPrecedence doc_precedence = GetExprPrecedence(doc); + if (doc_precedence < parent_precedence || + (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) { + output_ << "("; + PrintDoc(doc); + output_ << ")"; + } else { + PrintDoc(doc); + } + } + + /*! + * \brief Print expression and add parenthesis if doc has lower precedence than parent. + */ + void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent, + bool parenthesis_for_same_precedence = false) { + ExprPrecedence parent_precedence = GetExprPrecedence(parent); + return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence); + } + + /*! + * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent. + * + * This function should be used to print an child expression that needs to be wrapped + * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b` + * and the `b` and `c` in `a if b else c`. + */ + void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) { + PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence*/ true); + } + void MaybePrintCommentInline(const StmtDoc& stmt) { if (stmt->comment.defined()) { const std::string& comment = stmt->comment.value(); @@ -161,12 +305,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { - PrintDoc(doc->value); + PrintChildExpr(doc->value, doc); output_ << "." << doc->name; } void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { - PrintDoc(doc->value); + PrintChildExpr(doc->value, doc); if (doc->indices.size() == 0) { output_ << "[()]"; } else { @@ -226,21 +370,30 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { // Unary Operators ICHECK_EQ(doc->operands.size(), 1); output_ << OperatorToString(doc->kind); - PrintDoc(doc->operands[0]); + PrintChildExpr(doc->operands[0], doc); + } else if (doc->kind == OpKind::kPow) { + // Power operator is different than other binary operators + // It's right-associative and binds less tightly than unary operator on its right. + // https://docs.python.org/3/reference/expressions.html#the-power-operator + // https://docs.python.org/3/reference/expressions.html#operator-precedence + ICHECK_EQ(doc->operands.size(), 2); + PrintChildExprConservatively(doc->operands[0], doc); + output_ << " ** "; + PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); } else if (doc->kind < OpKind::kBinaryEnd) { // Binary Operator ICHECK_EQ(doc->operands.size(), 2); - PrintDoc(doc->operands[0]); + PrintChildExpr(doc->operands[0], doc); output_ << " " << OperatorToString(doc->kind) << " "; - PrintDoc(doc->operands[1]); + PrintChildExprConservatively(doc->operands[1], doc); } else if (doc->kind == OpKind::kIfThenElse) { ICHECK_EQ(doc->operands.size(), 3) << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); - PrintDoc(doc->operands[1]); + PrintChildExpr(doc->operands[1], doc); output_ << " if "; - PrintDoc(doc->operands[0]); + PrintChildExprConservatively(doc->operands[0], doc); output_ << " else "; - PrintDoc(doc->operands[2]); + PrintChildExprConservatively(doc->operands[2], doc); } else { LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast(doc->kind); throw; @@ -248,7 +401,7 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { - PrintDoc(doc->callee); + PrintChildExpr(doc->callee, doc); output_ << "("; @@ -285,7 +438,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { output_ << "lambda "; PrintJoinedDocs(doc->args, ", "); output_ << ": "; - PrintDoc(doc->body); + PrintChildExpr(doc->body, doc); } void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { @@ -444,8 +597,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { PrintJoinedDocs(doc->args, ", "); output_ << ")"; - output_ << " -> "; - PrintDoc(doc->return_type); + if (doc->return_type.defined()) { + output_ << " -> "; + PrintDoc(doc->return_type.value()); + } output_ << ":"; diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 040a82901059..f27bc71b6619 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -21,30 +21,31 @@ import pytest +import tvm from tvm.script.printer.doc import ( - LiteralDoc, - IdDoc, + AssertDoc, + AssignDoc, AttrAccessDoc, - IndexDoc, CallDoc, - OperationKind, - OperationDoc, + ClassDoc, + DictDoc, + ExprStmtDoc, + ForDoc, + FunctionDoc, + IdDoc, + IfDoc, + IndexDoc, LambdaDoc, - TupleDoc, ListDoc, - DictDoc, + LiteralDoc, + OperationDoc, + OperationKind, + ReturnDoc, + ScopeDoc, SliceDoc, StmtBlockDoc, - AssignDoc, - IfDoc, + TupleDoc, WhileDoc, - ForDoc, - ScopeDoc, - ExprStmtDoc, - AssertDoc, - ReturnDoc, - FunctionDoc, - ClassDoc, ) @@ -450,6 +451,13 @@ def test_return_doc(): [IdDoc("test"), IdDoc("test2")], ], ) +@pytest.mark.parametrize( + "return_type", + [ + None, + LiteralDoc(None), + ], +) @pytest.mark.parametrize( "body", [ @@ -458,9 +466,8 @@ def test_return_doc(): [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], ], ) -def test_function_doc(args, decorators, body): +def test_function_doc(args, decorators, return_type, body): name = IdDoc("name") - return_type = LiteralDoc(None) doc = FunctionDoc(name, args, decorators, return_type, body) @@ -504,3 +511,7 @@ def test_stmt_doc_comment(): comment = "test comment" doc.comment = comment assert doc.comment == comment + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 523f62d8b59f..e0905cc14540 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest import itertools +import pytest + +import tvm from tvm.script.printer.doc import ( AssertDoc, AssignDoc, @@ -701,29 +703,32 @@ def test_print_return_doc(value, expected): @pytest.mark.parametrize( - "args, decorators, body, expected", + "args, decorators, return_type, body, expected", [ ( [], [], + None, [], """ - def func() -> None: + def func(): pass """, ), ( [AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))], [], + IdDoc("int"), [], """ - def func(x: int) -> None: + def func(x: int) -> int: pass """, ), ( [AssignDoc(IdDoc("x"), rhs=LiteralDoc(1), annotation=IdDoc("int"))], [], + LiteralDoc(None), [], """ def func(x: int = 1) -> None: @@ -733,6 +738,7 @@ def func(x: int = 1) -> None: ( [], [IdDoc("wrap")], + LiteralDoc(None), [], """ @wrap @@ -743,6 +749,7 @@ def func() -> None: ( [], [IdDoc("wrap_outter"), IdDoc("wrap_inner")], + LiteralDoc(None), [], """ @wrap_outter @@ -757,6 +764,7 @@ def func() -> None: AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")), ], [IdDoc("wrap")], + LiteralDoc(None), [], """ @wrap @@ -770,6 +778,7 @@ def func(x: int, y: int = 1) -> None: AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")), ], [IdDoc("wrap")], + LiteralDoc(None), [ AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, [IdDoc("x"), LiteralDoc(1)])), AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, [IdDoc("y"), LiteralDoc(1)])), @@ -784,8 +793,8 @@ def func(x: int, y: int = 1) -> None: ], ids=itertools.count(), ) -def test_print_function_doc(args, decorators, body, expected): - doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body) +def test_print_function_doc(args, decorators, body, return_type, expected): + doc = FunctionDoc(IdDoc("func"), args, decorators, return_type, body) assert to_python_script(doc) == format_script(expected) # test @@ -1038,3 +1047,297 @@ def test_print_invalid_multiline_doc_comment(doc): with pytest.raises(ValueError) as e: to_python_script(doc) assert "cannot have newline" in str(e.value) + + +def generate_expr_precedence_test_cases(): + x = IdDoc("x") + y = IdDoc("y") + z = IdDoc("z") + + def negative(a): + return OperationDoc(OperationKind.USub, [a]) + + def invert(a): + return OperationDoc(OperationKind.Invert, [a]) + + def add(a, b): + return OperationDoc(OperationKind.Add, [a, b]) + + def sub(a, b): + return OperationDoc(OperationKind.Sub, [a, b]) + + def mult(a, b): + return OperationDoc(OperationKind.Mult, [a, b]) + + def div(a, b): + return OperationDoc(OperationKind.Div, [a, b]) + + def mod(a, b): + return OperationDoc(OperationKind.Mod, [a, b]) + + def pow(a, b): + return OperationDoc(OperationKind.Pow, [a, b]) + + def lshift(a, b): + return OperationDoc(OperationKind.LShift, [a, b]) + + def bit_and(a, b): + return OperationDoc(OperationKind.BitAnd, [a, b]) + + def bit_or(a, b): + return OperationDoc(OperationKind.BitOr, [a, b]) + + def bit_xor(a, b): + return OperationDoc(OperationKind.BitXor, [a, b]) + + def lt(a, b): + return OperationDoc(OperationKind.Lt, [a, b]) + + def eq(a, b): + return OperationDoc(OperationKind.Eq, [a, b]) + + def not_eq(a, b): + return OperationDoc(OperationKind.NotEq, [a, b]) + + def if_then_else(a, b, c): + return OperationDoc(OperationKind.IfThenElse, [a, b, c]) + + test_cases = { + "attr-call-index": [ + ( + add(x, y).attr("test"), + "(x + y).test", + ), + ( + add(x, y.attr("test")), + "x + y.test", + ), + ( + x[z].call(y), + "x[z](y)", + ), + ( + x.call(y)[z], + "x(y)[z]", + ), + ( + x.call(y).call(z), + "x(y)(z)", + ), + ( + x.call(y).attr("test"), + "x(y).test", + ), + ( + x.attr("test").call(y), + "x.test(y)", + ), + ( + x.attr("test").attr("test2"), + "x.test.test2", + ), + ( + LambdaDoc([x], x).call(y), + "(lambda x: x)(y)", + ), + ( + add(x, y)[z][add(z, z)].attr("name"), + "(x + y)[z][z + z].name", + ), + ], + "power": [ + ( + pow(pow(x, y), z), + "(x ** y) ** z", + ), + ( + pow(x, pow(y, z)), + "x ** y ** z", + ), + ( + pow(negative(x), negative(y)), + "(-x) ** -y", + ), + ( + pow(add(x, y), add(y, z)), + "(x + y) ** (y + z)", + ), + ], + "unary": [ + ( + invert(negative(y)), + "~-y", + ), + ( + negative(y).attr("test"), + "(-y).test", + ), + ( + negative(y.attr("test")), + "-y.test", + ), + ( + mult(negative(x), negative(y)), + "-x * -y", + ), + ( + negative(add(invert(x), negative(y))), + "-(~x + -y)", + ), + ], + "add-mult": [ + ( + mult(x, mult(y, z)), + "x * (y * z)", + ), + ( + mult(mult(x, y), z), + "x * y * z", + ), + ( + mult(x, add(y, z)), + "x * (y + z)", + ), + ( + mult(add(y, z), x), + "(y + z) * x", + ), + ( + add(x, mod(y, z)), + "x + y % z", + ), + ( + add(mult(y, z), x), + "y * z + x", + ), + ( + add(add(x, y), add(y, z)), + "x + y + (y + z)", + ), + ( + div(add(x, y), add(y, z)), + "(x + y) / (y + z)", + ), + ], + "shift": [ + ( + div(x, lshift(y, z)), + "x / (y << z)", + ), + ( + mult(lshift(y, z), x), + "(y << z) * x", + ), + ( + lshift(x, mult(y, z)), + "x << y * z", + ), + ( + lshift(mult(x, y), z), + "x * y << z", + ), + ( + lshift(mult(x, y), z), + "x * y << z", + ), + ( + lshift(lshift(x, y), z), + "x << y << z", + ), + ( + lshift(x, lshift(y, z)), + "x << (y << z)", + ), + ], + "bitwise": [ + ( + add(bit_or(x, y), bit_or(y, z)), + "(x | y) + (y | z)", + ), + ( + bit_and(bit_or(x, y), bit_or(y, z)), + "(x | y) & (y | z)", + ), + ( + bit_or(bit_and(x, y), bit_and(y, z)), + "x & y | y & z", + ), + ( + bit_and(bit_xor(x, bit_or(y, z)), z), + "(x ^ (y | z)) & z", + ), + ], + "comparison": [ + ( + not_eq(add(x, y), z), + "x + y != z", + ), + ( + eq(pow(x, y), z), + "x ** y == z", + ), + ( + lt(x, div(y, z)), + "x < y / z", + ), + ( + lt(x, if_then_else(y, y, y)), + "x < (y if y else y)", + ), + ], + "if-then-else": [ + ( + if_then_else(x, if_then_else(y, y, y), z), + "y if y else y if x else z", + ), + ( + if_then_else(if_then_else(x, x, x), y, z), + "y if (x if x else x) else z", + ), + ( + if_then_else(x, y, if_then_else(z, z, z)), + "y if x else (z if z else z)", + ), + ( + if_then_else(lt(x, x), add(y, y), mult(z, z)), + "y + y if x < x else z * z", + ), + ( + if_then_else(LambdaDoc([x], x), LambdaDoc([y], y), LambdaDoc([z], z)), + "(lambda y: y) if (lambda x: x) else (lambda z: z)", + ), + ], + "lambda": [ + ( + LambdaDoc([x, y], add(z, z)), + "lambda x, y: z + z", + ), + ( + add(LambdaDoc([x, y], z), z), + "(lambda x, y: z) + z", + ), + ( + LambdaDoc([x, y], add(z, z)).call(x, y), + "(lambda x, y: z + z)(x, y)", + ), + ( + LambdaDoc([x], LambdaDoc([y], z)), + "lambda x: lambda y: z", + ), + ], + } + + return [ + pytest.param(*args, id=f"{group_name}-{i}") + for group_name, cases in test_cases.items() + for i, args in enumerate(cases) + ] + + +@pytest.mark.parametrize("doc, expected", generate_expr_precedence_test_cases()) +def test_expr_precedence(doc, expected): + assert to_python_script(doc) == format_script(expected) + + +if __name__ == "__main__": + tvm.testing.main()