Skip to content

Commit

Permalink
[TVMScript][Fix] Correct round-trip of explicit root block (apache#12673
Browse files Browse the repository at this point in the history
)

* [TVMScript][Fix] Correct round-trip of explicit root block

Prior to this commit, when converting TIR to TVMScript, the root
`tir::Block` is typically hidden.  When parsing, however,
`tvm::tir::ScriptComplete` will wrap the function body in a root block
if the primfunc if the contains at least one block and does not
already have a root block.  As a result, if the root block is the only
block present, it would be stripped by a round-trip.

This commit tightens the condition for hiding the root `tir::Block`
when converting to TVMScript, so that it is printed in cases where
the autocompleter would reinsert it when parsing.
  • Loading branch information
Lunderberg authored Sep 21, 2022
1 parent b051cad commit fdc6894
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 17 deletions.
32 changes: 32 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,38 @@ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
* \return The renewed func.
*/
TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);

/*!
* \brief Check if the statement contains the specified node type.
*
* This utility potentially walks the entire statement, and should
* therefore not be used if it could otherwise be merged with another
* pass.
*
* \param stmt The statement to be searched
* \return Whether stmt contains Node
*/
template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
bool ContainsNode(const Stmt& stmt) {
struct Visitor : StmtVisitor {
// Early bail-out, if we already found the node.
void VisitStmt(const Stmt& stmt) {
if (contains_node) {
return;
}
StmtVisitor::VisitStmt(stmt);
}

void VisitStmt_(const Node* block) override { contains_node = true; }

bool contains_node{false};
};

Visitor visitor;
visitor(stmt);
return visitor.contains_node;
}

} // namespace tir
} // namespace tvm

Expand Down
50 changes: 42 additions & 8 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1664,19 +1664,53 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
}
// print body
body << "# body" << Doc::NewLine();
if (op->body->IsInstance<BlockRealizeNode>() &&
op->body.as<BlockRealizeNode>()->iter_values.empty()) {
const BlockNode* block = op->body.as<BlockRealizeNode>()->block.get();
if (block->annotations.empty() && !ContainsOptionalInfo(GetRef<Stmt>(block))) {
// Skip print root block
body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
body << PrintBlockBody(block);

Optional<Block> elided_root_block_body = [&]() -> Optional<Block> {
auto block_realize = op->body.as<BlockRealizeNode>();
if (!block_realize || block_realize->iter_values.size()) {
return NullOpt;
}

const auto& block = block_realize->block;
if (block->annotations.size() || ContainsOptionalInfo(block)) {
return NullOpt;
}

// The autocomplete might recognize the body itself as being a
// root block, and fail to insert it.
bool autocomplete_would_insert_root_block = [&]() -> bool {
if (block->alloc_buffers.size()) {
return true;
}

auto* block_realize = block->body.as<BlockRealizeNode>();
if (block_realize && block_realize->block->iter_vars.size()) {
return true;
}
if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) {
return true;
}
return false;
}();

if (autocomplete_would_insert_root_block) {
return block;
} else {
body << PrintBody(op->body);
return NullOpt;
}
}();

if (elided_root_block_body) {
// Skip printing of root block in cases where tvm::tir::ScriptComplete
// would re-insert it.
body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
body << PrintBlockBody(elided_root_block_body.value().get());
} else {
// If this is a non-root block, or is an unskippable root block,
// just print it without skipping.
body << PrintBody(op->body);
}

// print func attrs
Doc header_attr;
if (primFunc->attrs.defined()) {
Expand Down
37 changes: 28 additions & 9 deletions src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,35 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
for (const auto& alloc : root_allocates) {
buffer_var_map.Set(alloc->data, alloc);
}
bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
ScriptCompleter script_completer(&buffer_var_map);
// generate surrounding loops automatically
Stmt res = script_completer(func->body);
// generate root block automatically
if ((script_completer.contains_block || root_allocates.size()) && !contain_root) {
res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
res = BlockRealize({}, Bool(true), Downcast<Block>(res));

Stmt res = func->body;

// Generate root block automatically. This is done before
// ScriptCompleter, in order to fill the root block's T.reads() and
// T.writes() annotations, as if it had been explicitly written.
bool should_insert_root = [&]() -> bool {
if (root_allocates.size()) {
return true;
}
auto* block_realize = func->body.as<BlockRealizeNode>();
if (block_realize && block_realize->block->iter_vars.size()) {
return true;
}
if (!block_realize && ContainsNode<BlockRealizeNode>(func->body)) {
return true;
}
return false;
}();

if (should_insert_root) {
Block root_block({}, {}, {}, "root", std::move(res), NullOpt, root_allocates);
res = BlockRealize({}, Bool(true), std::move(root_block));
}

// generate surrounding loops automatically
ScriptCompleter script_completer(&buffer_var_map);
res = script_completer(std::move(res));

if (func->body.same_as(res)) {
return func;
} else {
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,25 @@ def func_root_attr():
return func_root_attr


def func_trivial_root_block():
@T.prim_func
def func(A: T.Buffer[1, "int32"]):
with T.block("root"):
A[0] = 0

return func


def func_nested_root_block():
@T.prim_func
def func(A: T.Buffer[1, "int32"]):
with T.block("root"):
with T.block("block"):
A[0] = 0

return func


def func_T_ptr_let_statement():
@T.prim_func
def func_T_ptr_let_statement(
Expand Down Expand Up @@ -3418,6 +3437,8 @@ def func() -> None:
func_with_target_spec_by_config,
func_with_target_spec_by_str,
func_root_attr,
func_trivial_root_block,
func_nested_root_block,
func_T_ptr_let_statement,
func_T_ptr_allocate,
llvm_intrin_call,
Expand Down

0 comments on commit fdc6894

Please sign in to comment.