[TIR, TVMScript] Update printer / parser to make T.allocate return buffer var
vinx13 committed Aug 12, 2022
1 parent 1f97f1f commit 628b844
Showing 21 changed files with 432 additions and 351 deletions.
57 changes: 30 additions & 27 deletions python/tvm/script/tir/scope_handler.py
Expand Up @@ -112,17 +112,17 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
scope = tvm.runtime.convert(scope)

return tvm.tir.Allocate(
self.buffer.data,
self.buffer.dtype,
self.buffer.shape,
self.buffer_var,
dtype,
extents,
condition,
self.body,
annotations=annotations,
span=span,
)

super().__init__(allocate, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
@@ -146,20 +146,15 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(
def setup_buffer_var(
extents, dtype, scope, condition=True, annotations=None, span: Span = None
):
"""Setup buffer object for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=extents,
dtype=dtype,
name=name,
scope=scope,
span=span,
)
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)
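
For illustration, a minimal TVMScript sketch of the new parser behavior (the function and buffer names here are hypothetical, not part of the commit): the symbol bound by T.allocate is now the pointer-typed buffer var itself, so element access goes through an aliasing T.buffer_decl.

    from tvm.script import tir as T

    @T.prim_func
    def after_change() -> None:
        # `data` is now a tir.Var with PointerType, not a tir.Buffer
        data = T.allocate([128], "float32", "global")
        # alias a Buffer over the raw allocation to index into it
        buf = T.buffer_decl([128], "float32", data=data, scope="global")
        buf[0] = 0.0
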


@register
@@ -176,7 +171,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
list_data.append(i.value)
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
n = tvm.tir.AllocateConst(
self.buffer.data,
self.buffer_var,
dtype,
shape,
nd_data,
@@ -187,7 +182,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
return n

super().__init__(allocate_const, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
@@ -211,17 +206,13 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
"""Setup buffer var for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=shape,
dtype=dtype,
name=name,
span=span,
)
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)
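
The same applies to constants. A brief sketch (function name hypothetical; the allocate_const call mirrors the updated workspace test below): the bound symbol is now the constant allocation's buffer var.

    @T.prim_func
    def const_example() -> None:
        # the bound symbol is the constant allocation's buffer var
        cdata = T.allocate_const([0, 1, 2, 3, 4, 5, 6, 7], "int8", [8])
        T.evaluate(0)
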


@register
@@ -248,7 +239,18 @@ def decl_buffer(
axis_separators=None,
span=None,
):
return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
if data is None:
# when data is not specified, the buffer is implicitly allocated
return tvm.tir.Allocate(
self.buffer.data,
dtype,
shape,
tvm.runtime.convert(True),
decl_buffer,
span=span,
)
return decl_buffer

super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
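
A sketch of the resulting sugar (names hypothetical): when no data= argument is given, T.decl_buffer now implies its own allocation, so the former allocate-then-alias pair collapses into a single declaration.

    @T.prim_func
    def implicit_alloc() -> None:
        # no data= given, so this DeclBuffer is wrapped in an implicit Allocate
        B = T.decl_buffer([64], "float32", scope="local")
        B[0] = 0.0
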

@@ -298,6 +300,7 @@ def setup_buffer(
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
name=name,
span=span,
)

127 changes: 66 additions & 61 deletions src/printer/tvmscript_printer.cc
@@ -100,13 +100,21 @@ class BufferUsageFinder : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const DeclBufferNode* op) final {
buffers_declared_.insert(op->buffer.get());
StmtExprVisitor::VisitStmt_(op);
}

private:
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
if (buffers_declared_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
@@ -119,6 +127,9 @@ class BufferUsageFinder : public StmtExprVisitor {
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
// The buffers declared via `DeclBuffer`. These buffers are excluded from the result because
// T.buffer_decl shouldn't be printed for them.
std::unordered_set<const BufferNode*> buffers_declared_;
};

/*!
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
return Doc();
}

namespace {
struct AllocUsage {
Buffer alloc_buffer;
Array<Buffer> aliasing_buffers;
};

template <typename AllocNode>
AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
Map<Var, Array<Buffer>>& cache = *cache_ptr;
if (!cache.count(op->buffer_var)) {
cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {
const Var& buffer_var = allocate->buffer_var;
const Buffer& buffer = decl_buffer->buffer;
if (!buffer_var.same_as(buffer->data)) {
return false;
}
Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});

auto is_exact_match = [](Buffer a, Buffer b) {
if (a->dtype != b->dtype) return false;
if (a->shape.size() != b->shape.size()) return false;

arith::Analyzer analyzer;
for (size_t i = 0; i < a->shape.size(); i++) {
if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
return false;
}
}
return true;
};

// If the buffer allocated via T.allocate is an exact match to the
// usage of the buffer later on, then that buffer is the return
// value of T.allocate, and no T.buffer_decl statement is needed.
Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
0, kDefault);
bool found_alloc_buf = false;
Array<Buffer> aliasing_buffers;
for (const auto& buf : buffer_usage) {
if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
alloc_buffer = buf;
found_alloc_buf = true;
} else {
aliasing_buffers.push_back(buf);
if (allocate->dtype != buffer->dtype) {
return false;
}
if (!is_one(allocate->condition)) {
return false;
}
if (allocate->annotations.size()) {
return false;
}
if (allocate->extents.size() != buffer->shape.size()) {
return false;
}
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}

return AllocUsage{alloc_buffer, aliasing_buffers};
return true;
}
} // namespace

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
auto usage = FindAllocateUsage(op, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(op->buffer_var.get());

if (!buffer_var_usage_.count(op->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});

if (buffer_usage.empty()) {
if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
// As syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
// DeclBuffer statement. We intentionally call `Print` instead of `PrintBody` here to
// delegate the printing of the current node to `DeclBufferNode` while maintaining the
// same value of `current_num_` and `num_child_`.
return Print(op->body);
}
}
}

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
@@ -1098,12 +1102,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
<< PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(
4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
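
For reference, a hedged Python sketch of the Allocate + DeclBuffer pattern this branch folds, constructed by hand (names hypothetical; APIs as used elsewhere in this commit):

    import tvm
    from tvm import tir

    # An Allocate whose direct body is a DeclBuffer over the same var,
    # dtype, and extents, with a trivially-true condition and no annotations.
    var = tir.Var("B", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
    buf = tir.decl_buffer([64], "float32", name="B", data=var, scope="local")
    stmt = tir.Allocate(var, "float32", [64], tvm.runtime.convert(True),
                        tir.DeclBuffer(buf, tir.Evaluate(0)))
    # The printer now emits this pattern as a single line:
    #     B = T.decl_buffer([64], "float32", scope="local")
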
@@ -1139,11 +1143,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
}
auto ndarray_str = ss.str();

auto usage = FindAllocateUsage(alloc, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(alloc->buffer_var.get());

if (!buffer_var_usage_.count(alloc->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({});

Doc func_call;
func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
@@ -1152,12 +1157,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
Doc doc;
var_not_in_headers_.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage)
<< PrintBody(alloc->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body);
doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body);
}
return doc;
}
@@ -63,9 +63,9 @@ def main(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [14*14*512*256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
B_local = T.decl_buffer([64], "float32", scope="local")
Apad_shared = T.decl_buffer([512], "float32", scope="shared")
Apad_shared_local = T.decl_buffer([8], "float32", scope="local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
@@ -105,9 +105,9 @@ def main(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [14*14*512*256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([6400000], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
B_local = T.decl_buffer([6400000], "float32", scope="local")
Apad_shared = T.decl_buffer([512], "float32", scope="shared")
Apad_shared_local = T.decl_buffer([8], "float32", scope="local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
@@ -151,9 +151,9 @@ def main(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [14*14*512*256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512000], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
B_local = T.decl_buffer([64], "float32", scope="local")
Apad_shared = T.decl_buffer([512000], "float32", scope="shared")
Apad_shared_local = T.decl_buffer([8], "float32", scope="local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
@@ -197,9 +197,9 @@ def main(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [14*14*512*256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
B_local = T.decl_buffer([64], "float32", scope="local")
Apad_shared = T.decl_buffer([512], "float32", scope="shared")
Apad_shared_local = T.decl_buffer([8], "float32", scope="local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
18 changes: 9 additions & 9 deletions tests/python/unittest/test_tir_analysis_calculate_workspace.py
@@ -31,8 +31,8 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand
placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1)
T_cast_49 = T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
# body
PaddedInput_22 = T.allocate([131072], "int16", "global")
DepthwiseConv2d_9 = T.allocate([100352], "int32", "global")
PaddedInput_22 = T.decl_buffer([131072], "int16", scope="global")
DepthwiseConv2d_9 = T.decl_buffer([100352], "int32", scope="global")
for i1_29, i2_39, i3_40 in T.grid(16, 16, 512):
PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), placeholder_147[((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)], T.int16(0), dtype="int16")
for i_9, j_9, c_9 in T.grid(14, 14, 512):
@@ -63,25 +63,25 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl
T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
sid_21 = T.allocate_const([0,1,2,3,4,5,6,7], "int8", [8])
# body
PaddedInput_25 = T.allocate([131072], "int16", "global")
PaddedInput_25 = T.decl_buffer([131072], "int16", scope="global")
for i1_35, i2_46, i3_47 in T.grid(16, 16, 512):
PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16")
T_add_11 = T.allocate([100352], "int32", "global")
with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11:
T_add_11 = T.decl_buffer([100352], "int32", scope="global")
with T.decl_buffer([100352], "int32", scope="global") as DepthwiseConv2d_11:
for i_11, j_11, c_11 in T.grid(14, 14, 512):
DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0
for di_11, dj_11 in T.grid(3, 3):
DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32")))
for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512):
T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47])
compute_22 = T.allocate([100352], "int32", "global")
with T.allocate([100352], "int32", "global") as T_cast_78:
compute_22 = T.decl_buffer([100352], "int32", scope="global")
with T.decl_buffer([100352], "int32", scope="global") as T_cast_78:
for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512):
T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)]
for i1_36, i2_47, i3_48 in T.grid(14, 14, 512):
compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32")
T_cast_79 = T.allocate([100352], "uint8", "global")
with T.allocate([100352], "int32", "global") as compute_23:
T_cast_79 = T.decl_buffer([100352], "uint8", scope="global")
with T.decl_buffer([100352], "int32", scope="global") as compute_23:
for i1_37, i2_48, i3_49 in T.grid(14, 14, 512):
compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0)
for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512):