Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] StmtDoc Printing #12112

Merged
merged 1 commit into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/script/printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
PrintTypedDoc(GetRef<AssignDoc>(doc_node));
} else if (const auto* doc_node = doc.as<IfDocNode>()) {
PrintTypedDoc(GetRef<IfDoc>(doc_node));
} else if (const auto* doc_node = doc.as<WhileDocNode>()) {
PrintTypedDoc(GetRef<WhileDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ForDocNode>()) {
PrintTypedDoc(GetRef<ForDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssertDocNode>()) {
PrintTypedDoc(GetRef<AssertDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
} else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
Expand Down
63 changes: 59 additions & 4 deletions src/script/printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,22 @@ class DocPrinter {
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;

/*!
* \brief Virtual method to print a IdDoc
* \brief Virtual method to print an IdDoc
*/
virtual void PrintTypedDoc(const IdDoc& doc) = 0;

/*!
* \brief Virtual method to print a AttrAccessDoc
* \brief Virtual method to print an AttrAccessDoc
*/
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;

/*!
* \brief Virtual method to print a IndexDoc
* \brief Virtual method to print an IndexDoc
*/
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;

/*!
* \brief Virtual method to print a OperationDoc
* \brief Virtual method to print an OperationDoc
*/
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;

Expand Down Expand Up @@ -133,6 +133,61 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;

/*!
* \brief Virtual method to print a StmtBlockDoc
*/
virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssignDoc
*/
virtual void PrintTypedDoc(const AssignDoc& doc) = 0;

/*!
* \brief Virtual method to print an IfDoc
*/
virtual void PrintTypedDoc(const IfDoc& doc) = 0;

/*!
* \brief Virtual method to print a WhileDoc
*/
virtual void PrintTypedDoc(const WhileDoc& doc) = 0;

/*!
* \brief Virtual method to print a ForDoc
*/
virtual void PrintTypedDoc(const ForDoc& doc) = 0;

/*!
* \brief Virtual method to print a ScopeDoc
*/
virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;

/*!
* \brief Virtual method to print an ExprStmtDoc
*/
virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssertDoc
*/
virtual void PrintTypedDoc(const AssertDoc& doc) = 0;

/*!
* \brief Virtual method to print a ReturnDoc
*/
virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;

/*!
* \brief Virtual method to print a FunctionDoc
*/
virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;

/*!
* \brief Virtual method to print a ClassDoc
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
Expand Down
212 changes: 211 additions & 1 deletion src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/doc.h>

#include <algorithm>
#include <string>

#include "../../support/str_escape.h"
#include "../../support/utils.h"
#include "./base_doc_printer.h"

namespace tvm {
Expand All @@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const DictDoc& doc) final;
void PrintTypedDoc(const TupleDoc& doc) final;
void PrintTypedDoc(const SliceDoc& doc) final;
void PrintTypedDoc(const StmtBlockDoc& doc) final;
void PrintTypedDoc(const AssignDoc& doc) final;
void PrintTypedDoc(const IfDoc& doc) final;
void PrintTypedDoc(const WhileDoc& doc) final;
void PrintTypedDoc(const ForDoc& doc) final;
void PrintTypedDoc(const ExprStmtDoc& doc) final;
void PrintTypedDoc(const AssertDoc& doc) final;
void PrintTypedDoc(const ReturnDoc& doc) final;
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;

private:
void NewLineWithoutIndent() { output_ << "\n"; }

template <typename DocType>
void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
bool is_first = true;
Expand All @@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
PrintDoc(doc);
}
}

void PrintIndentedBlock(const Array<StmtDoc>& docs) {
IncreaseIndent();
for (const StmtDoc& d : docs) {
NewLine();
PrintDoc(d);
}
if (docs.empty()) {
NewLine();
output_ << "pass";
}
DecreaseIndent();
}

void PrintDecorators(const Array<ExprDoc>& decorators) {
for (const ExprDoc& decorator : decorators) {
output_ << "@";
PrintDoc(decorator);
NewLine();
}
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end();
CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey()
<< " cannot have newline.";
output_ << " # " << comment;
}
}

void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
for (const std::string& line : comment_lines) {
output_ << "# " << line;
NewLine();
}
}
}

void PrintBlockComment(const String& comment) {
IncreaseIndent();
NewLine() << "\"\"\"";

std::vector<std::string> comment_lines = support::Split(comment, '\n');
for (const std::string& line : comment_lines) {
if (line.empty()) {
// No indentation on empty line
output_ << "\n";
} else {
NewLine() << line;
}
}

NewLine() << "\"\"\"";
DecreaseIndent();
}
};

void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
Expand Down Expand Up @@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
}
}

void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
NewLine();
}
}

void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) {
PrintJoinedDocs(tuple_doc->elements, ", ");
} else {
PrintDoc(doc->lhs);
}

if (doc->annotation) {
output_ << ": ";
PrintDoc(doc->annotation.value());
}
if (doc->rhs) {
output_ << " = ";
PrintDoc(doc->rhs.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->then_branch);

if (!doc->else_branch.empty()) {
NewLine();
output_ << "else:";
PrintIndentedBlock(doc->else_branch);
}
}

void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "for ";
PrintDoc(doc->lhs);
output_ << " in ";
PrintDoc(doc->rhs);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
output_ << " as ";
PrintDoc(doc->lhs.value());
}
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
PrintDoc(doc->expr);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
output_ << "assert ";
PrintDoc(doc->test);
if (doc->msg.defined()) {
output_ << ", ";
PrintDoc(doc->msg.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
output_ << "return ";
PrintDoc(doc->value);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
for (const AssignDoc& arg_doc : doc->args) {
ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them.";
}

PrintDecorators(doc->decorators);

output_ << "def ";
PrintDoc(doc->name);

output_ << "(";
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);
Comment on lines +447 to +448
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we consider the case where return_type is not given?


output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
PrintDecorators(doc->decorators);

output_ << "class ";
PrintDoc(doc->name);
output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

String DocToPythonScript(Doc doc, int indent_spaces) {
PythonDocPrinter printer(indent_spaces);
printer.Append(doc);
Expand Down
Loading