Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, TVMScript] Update printer / parser to make T.allocate return buffer var #12412

Merged
merged 8 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 30 additions & 27 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,17 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
scope = tvm.runtime.convert(scope)

return tvm.tir.Allocate(
self.buffer.data,
self.buffer.dtype,
self.buffer.shape,
self.buffer_var,
dtype,
extents,
condition,
self.body,
annotations=annotations,
span=span,
)

super().__init__(allocate, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
Expand All @@ -146,20 +146,15 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(
def setup_buffer_var(
extents, dtype, scope, condition=True, annotations=None, span: Span = None
):
"""Setup buffer object for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=extents,
dtype=dtype,
name=name,
scope=scope,
span=span,
)
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)


@register
Expand All @@ -176,7 +171,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
list_data.append(i.value)
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
n = tvm.tir.AllocateConst(
self.buffer.data,
self.buffer_var,
dtype,
shape,
nd_data,
Expand All @@ -187,7 +182,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
return n

super().__init__(allocate_const, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
Expand All @@ -211,17 +206,13 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
"""Setup buffer var for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=shape,
dtype=dtype,
name=name,
span=span,
)
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)


@register
Expand All @@ -248,7 +239,18 @@ def decl_buffer(
axis_separators=None,
span=None,
):
return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
if data is None:
# when data is not specified, the buffer is implicitly allocated
return tvm.tir.Allocate(
self.buffer.data,
dtype,
shape,
tvm.runtime.convert(True),
decl_buffer,
span=span,
)
return decl_buffer

super().__init__(decl_buffer, concise_scope=True, def_symbol=True)

Expand Down Expand Up @@ -298,6 +300,7 @@ def setup_buffer(
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
name=name,
span=span,
)

Expand Down
128 changes: 70 additions & 58 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,22 @@ class BufferUsageFinder : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const DeclBufferNode* op) final {
buffers_declared_.insert(op->buffer.get());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also track which buffers have gone out of scope? If I'm understanding it correctly, a single DeclBufferNode would also allow for usage outside of the DeclBufferNode::body, where I'd expect it to only apply within the scope of the node.

StmtExprVisitor::VisitStmt_(op);
buffers_declared_.erase(op->buffer.get());
}

private:
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
if (buffers_declared_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
Expand All @@ -119,6 +128,9 @@ class BufferUsageFinder : public StmtExprVisitor {
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
// The buffers declared via `DeclBuffer`. These buffers are excluded from the result because
// T.buffer_decl shouldn't be printed for them.
std::unordered_set<const BufferNode*> buffers_declared_;
};

/*!
Expand Down Expand Up @@ -1055,58 +1067,57 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
}

namespace {
struct AllocUsage {
Buffer alloc_buffer;
Array<Buffer> aliasing_buffers;
};

template <typename AllocNode>
AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
Map<Var, Array<Buffer>>& cache = *cache_ptr;
if (!cache.count(op->buffer_var)) {
cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
const Var& buffer_var = allocate->buffer_var;
const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
if (!decl_buffer) {
return false;
}
Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});

auto is_exact_match = [](Buffer a, Buffer b) {
if (a->dtype != b->dtype) return false;
if (a->shape.size() != b->shape.size()) return false;

arith::Analyzer analyzer;
for (size_t i = 0; i < a->shape.size(); i++) {
if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
return false;
}
}
return true;
};

// If the buffer allocated via T.allocate is an exact match to the
// usage of the buffer later on, then that buffer is the return
// value of T.allocate, and no T.buffer_decl statement is needed.
Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
0, kDefault);
bool found_alloc_buf = false;
Array<Buffer> aliasing_buffers;
for (const auto& buf : buffer_usage) {
if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
alloc_buffer = buf;
found_alloc_buf = true;
} else {
aliasing_buffers.push_back(buf);
const Buffer& buffer = decl_buffer->buffer;
if (!buffer_var.same_as(buffer->data)) {
return false;
}
if (allocate->dtype != buffer->dtype) {
return false;
}
if (!is_one(allocate->condition)) {
return false;
}
if (allocate->annotations.size()) {
return false;
}
if (allocate->extents.size() != buffer->shape.size()) {
return false;
}
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}

return AllocUsage{alloc_buffer, aliasing_buffers};
return true;
}

} // namespace

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
auto usage = FindAllocateUsage(op, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(op->buffer_var.get());

if (!buffer_var_usage_.count(op->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});

if (buffer_usage.empty()) {
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
if (IsAllocateDeclBufferPattern(op)) {
// As syntactic sugar, we identify the Allocate + DeclBuffer pattern and print a single
// DeclBuffer statement. We intentionally call `Print` instead of `PrintBody` here to
// delegate the printing of the current node to `DeclBufferNode` while keeping the
// values of `current_num_` and `num_child_` unchanged.
return Print(op->body);
}
}

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
Expand All @@ -1124,12 +1135,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
<< PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(
4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
Expand Down Expand Up @@ -1179,11 +1190,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
}
auto ndarray_str = ss.str();

auto usage = FindAllocateUsage(alloc, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(alloc->buffer_var.get());

if (!buffer_var_usage_.count(alloc->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({});

Doc func_call;
func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
Expand All @@ -1192,12 +1204,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
Doc doc;
var_not_in_headers_.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage)
<< PrintBody(alloc->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body);
doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body);
}
return doc;
}
Expand Down
Loading