Adding annotations for tir.allocate (apache#9168)
* Adding annotations for tir.allocate

This commit adds annotations to the tir.allocate node,
to be used as hints for future transformations.

Change-Id: I02a3a875c38c3edd449385da5b741ef4958bb47f

* Adding annotations for tir.allocate

* Adding TVMScript support
* Adding TIR text printing support

Change-Id: Id0b6725b2e79c23f6b8ff192772f1ea4125a27c2
manupak authored and ylc committed Jan 7, 2022
1 parent 1b82b3e commit b2dbafd
Showing 8 changed files with 122 additions and 13 deletions.
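Taken together, the diffs below thread an optional annotations map through tir.Allocate end to end: the C++ node and constructor, the Python bindings, the TVMScript parser, and both printers. As a quick orientation, here is a minimal sketch of the resulting Python API, mirroring the new unit test at the bottom of this commit (the annotation keys are illustrative):

    import tvm
    from tvm import te

    # Build a tir.Allocate carrying annotation hints.
    dtype = "int8"
    ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), "global")
    buffer_var = te.var("buffer", ptype)

    alloc = tvm.tir.Allocate(
        buffer_var,
        dtype,
        [2, 2],
        tvm.get_global_func("tir.const_true")(dtype, None),  # condition: always true
        tvm.tir.Evaluate(0),                                 # trivial body
        annotations={"hint_key": "hint_value"},              # new optional argument
    )
    assert alloc.annotations["hint_key"] == "hint_value"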
14 changes: 12 additions & 2 deletions include/tvm/tir/stmt.h
@@ -521,20 +521,28 @@ class AllocateNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body to be executed. */
   Stmt body;
+  /*!
+   * \brief Additional annotations about the allocation.
+   *
+   * These annotations can be used as auxiliary hints
+   * for future transformations.
+   */
+  Map<String, ObjectRef> annotations;

   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
     v->Visit("dtype", &dtype);
     v->Visit("extents", &extents);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
+    v->Visit("annotations", &annotations);
     v->Visit("span", &span);
   }

   bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
     return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
            equal(extents, other->extents) && equal(condition, other->condition) &&
-           equal(body, other->body);
+           equal(body, other->body) && equal(annotations, other->annotations);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
@@ -543,6 +551,7 @@ class AllocateNode : public StmtNode {
     hash_reduce(extents);
     hash_reduce(condition);
     hash_reduce(body);
+    hash_reduce(annotations);
   }

   /*!
@@ -570,7 +579,8 @@
 class Allocate : public Stmt {
  public:
   TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Span span = Span());
+                   Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+                   Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
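Because annotations now participate in SEqualReduce and SHashReduce, two allocates that differ only in their annotations should no longer compare structurally equal. A small sketch of the observable effect from Python (the helper make_alloc is hypothetical, only to keep the example short):

    import tvm
    from tvm import te, tir

    def make_alloc(annotations):  # hypothetical helper for this sketch
        dtype = "int8"
        ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), "global")
        var = te.var("buffer", ptype)
        return tir.Allocate(
            var, dtype, [2, 2],
            tvm.get_global_func("tir.const_true")(dtype, None),
            tir.Evaluate(0),
            annotations=annotations,
        )

    # Same shape and body, different annotations: no longer structurally equal.
    assert not tvm.ir.structural_equal(make_alloc({"k": "a"}), make_alloc({"k": "b"}))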
16 changes: 12 additions & 4 deletions python/tvm/script/tir/scope_handler.py
@@ -104,14 +104,20 @@ def get_optional_vars(node, context):

 @register
 class Allocate(WithScopeHandler):
-    """With scope handler T.allocate(extents, dtype, scope, condition)"""
+    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""

     def __init__(self):
-        def allocate(extents, dtype, scope, condition=True, span=None):
+        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
             condition = tvm.runtime.convert(condition)
             scope = tvm.runtime.convert(scope)
             return tvm.tir.Allocate(
-                self.buffer_var, dtype, extents, condition, self.body, span=span
+                self.buffer_var,
+                dtype,
+                extents,
+                condition,
+                self.body,
+                annotations=annotations,
+                span=span,
             )

         super().__init__(allocate, concise_scope=True, def_symbol=True)
@@ -137,7 +143,9 @@ def enter_scope(
         else:
             raise Exception("Internal Bug")

-        def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
+        def setup_buffer_var(
+            extents, dtype, scope, condition=True, annotations=None, span: Span = None
+        ):
             """Setup buffer var for a given type."""
             buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
             self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
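With this handler extended, TVMScript accepts the extra keyword in both the with-scope and the concise form. A sketch of the concise spelling, closely mirroring the new roundtrip test at the end of this commit (the buffer name and annotation key are illustrative):

    from tvm.script import tir as T

    @T.prim_func
    def scratch() -> None:
        # `annotations` is forwarded to the underlying tir.Allocate node.
        buf = T.allocate([16], "uint8", "global", annotations={"hint": "value"})
        T.store(buf, 0, T.uint8(0), True)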
16 changes: 14 additions & 2 deletions python/tvm/tir/stmt.py
@@ -318,13 +318,25 @@ class Allocate(Stmt):
     body : Stmt
         The body statement.

+    annotations: Optional[Mapping[str, Object]]
+        Additional annotation hints
+
     span : Optional[Span]
         The location of this itervar in the source code.
     """

-    def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
+    def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
+        if annotations is None:
+            annotations = dict()
         self.__init_handle_by_constructor__(
-            _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span  # type: ignore
+            _ffi_api.Allocate,  # type: ignore
+            buffer_var,
+            dtype,
+            extents,
+            condition,
+            body,
+            annotations,
+            span,
         )


12 changes: 10 additions & 2 deletions src/printer/tir_text_printer.cc
@@ -449,8 +449,16 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
 Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
   Doc doc;
   auto scope = GetPtrStorageScope(op->buffer_var);
-  doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
-      << Print(op->extents) << "), storage_scope = " << scope;
+  doc << "allocate(" << Print(op->buffer_var) << ", ";
+  doc << PrintDType(op->dtype) << ", ";
+  doc << Print(op->extents) << "), storage_scope = " << scope;
+  if (!op->annotations.empty()) {
+    std::vector<Doc> attr_docs;
+    for (const auto& it : op->annotations) {
+      attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+    }
+    doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
+  }
   if (!is_one(op->condition)) {
     doc << " if " << Print(op->condition);
   }
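The new branch above appends the map to the textual form. A sketch of what the TIR text printer now emits, reusing the `alloc` node from the sketch near the top of this commit (expected fragment inferred from the new test in test_tir_nodes.py; map-entry order may vary):

    # Reusing `alloc` from the first sketch above:
    output = tvm.tir.PrimFunc([], alloc).astext()
    # Expected fragment:
    #   allocate(buffer: Pointer(global int8), int8, [2, 2]),
    #       storage_scope = global, annotations = {"hint_key": "hint_value"})
    assert "annotations =" in output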
10 changes: 10 additions & 0 deletions src/printer/tvmscript_printer.cc
@@ -769,6 +769,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
     if (!is_one(op->condition)) {
       doc << ", " << Print(op->condition);
     }
+    if (!op->annotations.empty()) {
+      doc << ", annotations={";
+      doc << PrintAnnotations(op->annotations);
+      doc << "}";
+    }
     doc << ") as " << Print(op->buffer_var) << ":";
     doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
   } else {
@@ -777,6 +782,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
     if (!is_one(op->condition)) {
       doc << ", " << Print(op->condition);
     }
+    if (!op->annotations.empty()) {
+      doc << ", annotations={";
+      doc << PrintAnnotations(op->annotations);
+      doc << "}";
+    }
     doc << ")" << Doc::NewLine() << PrintBody(op->body);
   }
   TryDeallocVar(op->buffer_var);
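Since the printer now emits annotations={...} inside the T.allocate(...) call, a script/parse roundtrip preserves the map, which is exactly what the new roundtrip test at the bottom of this commit checks. A compact sketch, reusing the `scratch` PrimFunc from the parser section above:

    import tvm
    import tvm.script

    # Print `scratch` back to TVMScript source, re-parse it, and compare.
    rt = tvm.script.from_source(scratch.script(show_meta=True))
    tvm.ir.assert_structural_equal(scratch, rt, True)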
7 changes: 4 additions & 3 deletions src/tir/ir/stmt.cc
@@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

 // Allocate
 Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Span span) {
+                   Stmt body, Map<String, ObjectRef> annotations, Span span) {
   CHECK(IsPointerType(buffer_var->type_annotation, dtype))
       << "The allocated data type (" << dtype
       << ") does not match the type annotation of the buffer " << buffer_var << " ("
@@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
   node->extents = std::move(extents);
   node->condition = std::move(condition);
   node->body = std::move(body);
+  node->annotations = std::move(annotations);
   node->span = std::move(span);
   data_ = std::move(node);
 }
@@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {

 TVM_REGISTER_GLOBAL("tir.Allocate")
     .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
-                       Stmt body, Span span) {
-      return Allocate(buffer_var, type, extents, condition, body, span);
+                       Stmt body, Map<String, ObjectRef> annotations, Span span) {
+      return Allocate(buffer_var, type, extents, condition, body, annotations, span);
     });

TVM_REGISTER_NODE_TYPE(AllocateNode);
33 changes: 33 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
@@ -473,5 +473,38 @@ def test_block_blockrealize():
     assert output.find("with init()") != -1


+def test_tir_allocate():
+    dtype = "int8"
+    storage_scope = "global"
+    ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
+    a = te.var("buffer", ptype)
+    allocate = tvm.tir.Allocate(
+        buffer_var=a,
+        dtype=dtype,
+        extents=[2, 2],
+        condition=tvm.get_global_func("tir.const_true")(dtype, None),
+        body=tvm.tir.Evaluate(2 + 1),
+        annotations={
+            "attr1": "foo",
+            "attr2": "bar",
+        },
+    )
+    assert allocate.buffer_var == a
+    assert allocate.dtype == "int8"
+    assert list(allocate.extents) == [2, 2]
+    assert allocate.annotations["attr1"] == "foo"
+    assert allocate.annotations["attr2"] == "bar"
+
+    # make sure we can print using TIRTextPrinter
+    func = tvm.tir.PrimFunc([], allocate)
+    output = func.astext()
+    assert (
+        output.find(
+            'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
+        )
+        != -1
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
27 changes: 27 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3059,5 +3059,32 @@ def test_while_loop():
     tvm.ir.assert_structural_equal(while_loop, rt_func)


+# fmt: off
+@T.prim_func
+def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+    placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+    T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    # body
+    tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"})
+    for ax0_ax1_fused_4 in T.serial(0, 56):
+        for ax2_4 in T.serial(0, 56):
+            for ax3_init in T.serial(0, 64):
+                T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True)
+    for ax0_ax1_fused_5 in T.serial(0, 56):
+        for ax2_5, ax3_3 in T.grid(56, 64):
+            T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+# fmt: on
+
+
+def test_primfunc_with_allocate_annotations():
+    func = primfunc_with_allocate_annotations
+    rt_func = tvm.script.from_source(func.script(show_meta=True))
+    tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
