Skip to content

Commit

Permalink
[TVMScript] Comments and docstrings printing (#13839)
Browse files Browse the repository at this point in the history
This PR introduces the `CommentDoc` for comments printing and `DocStringDoc` for docstring printing. It enables to add free comments and docstring as `stmt` in printing, e.g.
```python
# comment 1
# comment 2
"""
docstring 1
docstring 2
"""
```
The free here means to not be bound to any `stmt`, but acts as a single `stmt`, similar to `ExprStmtDoc` for `ExprDoc`. 

This PR also introduces an example for the `CommentDoc`, as follow up of #13819.
In the old printer, we always print a `# with T.block("root"):`, when there is an implicit root block skipped when printing. For example,
```
@T.prim_func
def main():
  # with T.block("root"):
  a = T.alloc_buffer((128, 128))
  for i, j in T.grid(128, 128):
    with T.block(""):
      ...
```
We bring this syntax reminder back in this PR.
In addition, we introduce a field of `ir_usage` and `print_headers` into the printer configuration, to support the printing of headers for `IRModule` and `PrimFunc`. For example,

```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module():
  @T.prim_func
  def func():
    ...
```
  • Loading branch information
cyx-6 authored Jan 26, 2023
1 parent 239edb5 commit 697fdb2
Show file tree
Hide file tree
Showing 16 changed files with 261 additions and 11 deletions.
44 changes: 44 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,50 @@ class ClassDoc : public StmtDoc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
};

/*!
* \brief Doc that represents comment.
*
* \sa CommentDoc
*/
class CommentDocNode : public StmtDocNode {
public:
static constexpr const char* _type_key = "script.printer.CommentDoc";
TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode);
};

/*!
* \brief Reference type of CommentDocNode.
*
* \sa CommentDocNode
*/
class CommentDoc : public StmtDoc {
public:
explicit CommentDoc(String comment);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode);
};

/*!
* \brief Doc that represents docstring.
*
* \sa DocStringDoc
*/
class DocStringDocNode : public StmtDocNode {
public:
static constexpr const char* _type_key = "script.printer.DocStringDoc";
TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode);
};

/*!
* \brief Reference type of DocStringDocNode.
*
* \sa DocStringDocNode
*/
class DocStringDoc : public StmtDoc {
public:
explicit DocStringDoc(String docs);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode);
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -148,6 +149,8 @@ class IRDocsifierNode : public Object {
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
/*! \brief The IR usages for headers printing */
std::unordered_set<std::string> ir_usage;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("frames", &frames);
Expand All @@ -156,6 +159,7 @@ class IRDocsifierNode : public Object {
// `obj2info` is not visited
// `defined_names` is not visited
// `common_prefix` is not visited
// `ir_usage` is not visited
}

static constexpr const char* _type_key = "script.printer.IRDocsifier";
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,23 @@ def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]):
decorators,
body,
)


@register_object("script.printer.CommentDoc")
class CommentDoc(StmtDoc):
"""Doc that represents comment."""

def __init__(self, comment: str):
self.__init_handle_by_constructor__(
_ffi_api.CommentDoc, comment # type: ignore # pylint: disable=no-member
)


@register_object("script.printer.DocStringDoc")
class DocStringDoc(StmtDoc):
"""Doc that represents docstring."""

def __init__(self, docs: str):
self.__init_handle_by_constructor__(
_ffi_api.DocStringDoc, docs # type: ignore # pylint: disable=no-member
)
22 changes: 22 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
this->data_ = std::move(n);
}

CommentDoc::CommentDoc(String comment) {
ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>();
n->comment = comment;
this->data_ = std::move(n);
}

DocStringDoc::DocStringDoc(String docs) {
ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>();
n->comment = docs;
this->data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DocNode);
TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
.set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
Expand Down Expand Up @@ -365,6 +377,16 @@ TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
return ClassDoc(name, decorators, body);
});

TVM_REGISTER_NODE_TYPE(CommentDocNode);
TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) {
return CommentDoc(comment);
});

TVM_REGISTER_NODE_TYPE(DocStringDocNode);
TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) {
return DocStringDoc(docs);
});

} // namespace printer
} // namespace script
} // namespace tvm
4 changes: 4 additions & 0 deletions src/script/printer/doc_printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else if (const auto* doc_node = doc.as<CommentDocNode>()) {
PrintTypedDoc(GetRef<CommentDoc>(doc_node));
} else if (const auto* doc_node = doc.as<DocStringDocNode>()) {
PrintTypedDoc(GetRef<DocStringDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
Expand Down
10 changes: 10 additions & 0 deletions src/script/printer/doc_printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Virtual method to print a CommentDoc
*/
virtual void PrintTypedDoc(const CommentDoc& doc) = 0;

/*!
* \brief Virtual method to print a DocStringDoc
*/
virtual void PrintTypedDoc(const DocStringDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
Expand Down
34 changes: 28 additions & 6 deletions src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;
void PrintTypedDoc(const CommentDoc& doc) final;
void PrintTypedDoc(const DocStringDoc& doc) final;

private:
void NewLineWithoutIndent() { output_ << "\n"; }
Expand Down Expand Up @@ -253,11 +255,19 @@ class PythonDocPrinter : public DocPrinter {
}
}

void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
bool first_line = true;
for (const std::string& line : comment_lines) {
output_ << "# " << line;
if (first_line) {
output_ << "# " << line;
first_line = false;
} else {
NewLine() << "# " << line;
}
}
if (new_line) {
NewLine();
}
}
Expand Down Expand Up @@ -523,7 +533,7 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";
Expand All @@ -538,7 +548,7 @@ void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";
Expand All @@ -547,7 +557,7 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "for ";
if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
if (tuple->elements.size() == 1) {
Expand All @@ -567,7 +577,7 @@ void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
Expand Down Expand Up @@ -642,6 +652,18 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
if (doc->comment.defined()) {
MaybePrintCommenMultiLines(doc, false);
}
}

void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
if (doc->comment.defined() && !doc->comment.value().empty()) {
output_ << "\"\"\"" << doc->comment.value() << "\"\"\"";
}
}

String DocToPythonScript(Doc doc, const PrinterConfig& cfg) {
if (cfg->num_context_lines < 0) {
cfg->num_context_lines = std::numeric_limits<int32_t>::max();
Expand Down
3 changes: 2 additions & 1 deletion src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
return s.value();
}
}
Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root());
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}

Expand Down
1 change: 1 addition & 0 deletions src/script/printer/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace printer {

/*! \brief Creates the IR common prefix, which is by default `I` */
inline ExprDoc IR(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("ir");
return IdDoc(d->cfg->ir_prefix)->Attr(attr);
}

Expand Down
4 changes: 3 additions & 1 deletion src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (implicit_root_block) {
tir::Block root_block = implicit_root_block.value();
ObjectPath root_block_p = p->Attr("body")->Attr("body");
(*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):"));
// Handle root block `alloc_buffer`
for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
tir::Buffer buffer = root_block->alloc_buffers[i];
Expand Down Expand Up @@ -181,7 +182,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});

std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) {
Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root());
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}

Expand Down
1 change: 1 addition & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class TIRFrame : public Frame {

/*! \brief Creates the TIR common prefix, which is by default `T` */
inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("tir");
return IdDoc(d->cfg->tir_prefix)->Attr(attr);
}

Expand Down
20 changes: 20 additions & 0 deletions src/script/printer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ inline std::string DType2Str(const runtime::DataType& dtype) {
return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
}

/*! \brief Add headers as comments to doc if needed */
inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
if (d->ir_usage.size()) {
Array<StmtDoc> stmts;
if (d->ir_usage.count("ir")) {
stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix));
}
if (d->ir_usage.count("tir")) {
stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix));
}
if (d->ir_usage.count("relax")) {
stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix));
}
stmts.push_back(CommentDoc(""));
stmts.push_back(Downcast<StmtDoc>(doc));
return StmtBlockDoc(stmts);
}
return doc;
}

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
AttrAccessDoc,
CallDoc,
ClassDoc,
CommentDoc,
DictDoc,
DocStringDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
Expand Down Expand Up @@ -505,6 +507,32 @@ def test_class_doc(decorators, body):
assert list(doc.body) == body


@pytest.mark.parametrize(
"comment",
[
"",
"test comment 1",
"test comment 1\ntest comment 1",
],
)
def test_comment_doc(comment):
doc = CommentDoc(comment)
assert doc.comment == comment


@pytest.mark.parametrize(
"comment",
[
"",
"test comment 1",
"test comment 1\ntest comment 1",
],
)
def test_doc_string_doc(comment):
doc = DocStringDoc(comment)
assert doc.comment == comment


def test_stmt_doc_comment():
doc = ExprStmtDoc(IdDoc("x"))
assert doc.comment is None
Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def test_ir_module():
_assert_print(
mod,
"""
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
Expand Down
Loading

0 comments on commit 697fdb2

Please sign in to comment.