Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TVMScript] StmtDoc Printing (apache#12112)
Browse files Browse the repository at this point in the history
This PR addes:

- StmtDoc Printing in PythonDocPrinter

Tracking issue: apache#11912
  • Loading branch information
yelite authored and xinetzone committed Nov 25, 2022
1 parent 204b256 commit 6a94245
Show file tree
Hide file tree
Showing 4 changed files with 906 additions and 6 deletions.
22 changes: 22 additions & 0 deletions src/script/printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
PrintTypedDoc(GetRef<AssignDoc>(doc_node));
} else if (const auto* doc_node = doc.as<IfDocNode>()) {
PrintTypedDoc(GetRef<IfDoc>(doc_node));
} else if (const auto* doc_node = doc.as<WhileDocNode>()) {
PrintTypedDoc(GetRef<WhileDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ForDocNode>()) {
PrintTypedDoc(GetRef<ForDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssertDocNode>()) {
PrintTypedDoc(GetRef<AssertDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
} else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
Expand Down
63 changes: 59 additions & 4 deletions src/script/printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,22 @@ class DocPrinter {
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;

/*!
* \brief Virtual method to print a IdDoc
* \brief Virtual method to print an IdDoc
*/
virtual void PrintTypedDoc(const IdDoc& doc) = 0;

/*!
* \brief Virtual method to print a AttrAccessDoc
* \brief Virtual method to print an AttrAccessDoc
*/
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;

/*!
* \brief Virtual method to print a IndexDoc
* \brief Virtual method to print an IndexDoc
*/
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;

/*!
* \brief Virtual method to print a OperationDoc
* \brief Virtual method to print an OperationDoc
*/
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;

Expand Down Expand Up @@ -133,6 +133,61 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;

/*!
* \brief Virtual method to print a StmtBlockDoc
*/
virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssignDoc
*/
virtual void PrintTypedDoc(const AssignDoc& doc) = 0;

/*!
* \brief Virtual method to print an IfDoc
*/
virtual void PrintTypedDoc(const IfDoc& doc) = 0;

/*!
* \brief Virtual method to print a WhileDoc
*/
virtual void PrintTypedDoc(const WhileDoc& doc) = 0;

/*!
* \brief Virtual method to print a ForDoc
*/
virtual void PrintTypedDoc(const ForDoc& doc) = 0;

/*!
* \brief Virtual method to print a ScopeDoc
*/
virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;

/*!
* \brief Virtual method to print an ExprStmtDoc
*/
virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssertDoc
*/
virtual void PrintTypedDoc(const AssertDoc& doc) = 0;

/*!
* \brief Virtual method to print a ReturnDoc
*/
virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;

/*!
* \brief Virtual method to print a FunctionDoc
*/
virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;

/*!
* \brief Virtual method to print a ClassDoc
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
Expand Down
212 changes: 211 additions & 1 deletion src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/doc.h>

#include <algorithm>
#include <string>

#include "../../support/str_escape.h"
#include "../../support/utils.h"
#include "./base_doc_printer.h"

namespace tvm {
Expand All @@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const DictDoc& doc) final;
void PrintTypedDoc(const TupleDoc& doc) final;
void PrintTypedDoc(const SliceDoc& doc) final;
void PrintTypedDoc(const StmtBlockDoc& doc) final;
void PrintTypedDoc(const AssignDoc& doc) final;
void PrintTypedDoc(const IfDoc& doc) final;
void PrintTypedDoc(const WhileDoc& doc) final;
void PrintTypedDoc(const ForDoc& doc) final;
void PrintTypedDoc(const ExprStmtDoc& doc) final;
void PrintTypedDoc(const AssertDoc& doc) final;
void PrintTypedDoc(const ReturnDoc& doc) final;
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;

private:
void NewLineWithoutIndent() { output_ << "\n"; }

template <typename DocType>
void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
bool is_first = true;
Expand All @@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
PrintDoc(doc);
}
}

void PrintIndentedBlock(const Array<StmtDoc>& docs) {
IncreaseIndent();
for (const StmtDoc& d : docs) {
NewLine();
PrintDoc(d);
}
if (docs.empty()) {
NewLine();
output_ << "pass";
}
DecreaseIndent();
}

void PrintDecorators(const Array<ExprDoc>& decorators) {
for (const ExprDoc& decorator : decorators) {
output_ << "@";
PrintDoc(decorator);
NewLine();
}
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end();
CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey()
<< " cannot have newline.";
output_ << " # " << comment;
}
}

void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
for (const std::string& line : comment_lines) {
output_ << "# " << line;
NewLine();
}
}
}

void PrintBlockComment(const String& comment) {
IncreaseIndent();
NewLine() << "\"\"\"";

std::vector<std::string> comment_lines = support::Split(comment, '\n');
for (const std::string& line : comment_lines) {
if (line.empty()) {
// No indentation on empty line
output_ << "\n";
} else {
NewLine() << line;
}
}

NewLine() << "\"\"\"";
DecreaseIndent();
}
};

void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
Expand Down Expand Up @@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
}
}

void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
NewLine();
}
}

void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) {
PrintJoinedDocs(tuple_doc->elements, ", ");
} else {
PrintDoc(doc->lhs);
}

if (doc->annotation) {
output_ << ": ";
PrintDoc(doc->annotation.value());
}
if (doc->rhs) {
output_ << " = ";
PrintDoc(doc->rhs.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->then_branch);

if (!doc->else_branch.empty()) {
NewLine();
output_ << "else:";
PrintIndentedBlock(doc->else_branch);
}
}

void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "for ";
PrintDoc(doc->lhs);
output_ << " in ";
PrintDoc(doc->rhs);
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
output_ << " as ";
PrintDoc(doc->lhs.value());
}
output_ << ":";

PrintIndentedBlock(doc->body);
}

void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
PrintDoc(doc->expr);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
output_ << "assert ";
PrintDoc(doc->test);
if (doc->msg.defined()) {
output_ << ", ";
PrintDoc(doc->msg.value());
}
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
output_ << "return ";
PrintDoc(doc->value);
MaybePrintCommentInline(doc);
}

void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
for (const AssignDoc& arg_doc : doc->args) {
ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them.";
}

PrintDecorators(doc->decorators);

output_ << "def ";
PrintDoc(doc->name);

output_ << "(";
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);

output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
PrintDecorators(doc->decorators);

output_ << "class ";
PrintDoc(doc->name);
output_ << ":";

if (doc->comment.defined()) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

String DocToPythonScript(Doc doc, int indent_spaces) {
PythonDocPrinter printer(indent_spaces);
printer.Append(doc);
Expand Down
Loading

0 comments on commit 6a94245

Please sign in to comment.