From b2dbafdcb825d5d290d36c4dff839ac36eb20440 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 7 Oct 2021 13:44:59 +0100 Subject: [PATCH] Adding annotations for tir.allocate (#9168) * Adding annotation for tir.allocate This commit is adding annotations for tir.allocate node to be used as hints for future transformations. Change-Id: I02a3a875c38c3edd449385da5b741ef4958bb47f * Adding annotation for tir.allocate * adding tvmscript support * adding tir text printing support Change-Id: Id0b6725b2e79c23f6b8ff192772f1ea4125a27c2 --- include/tvm/tir/stmt.h | 14 ++++++-- python/tvm/script/tir/scope_handler.py | 16 ++++++--- python/tvm/tir/stmt.py | 16 +++++++-- src/printer/tir_text_printer.cc | 12 +++++-- src/printer/tvmscript_printer.cc | 10 ++++++ src/tir/ir/stmt.cc | 7 ++-- tests/python/unittest/test_tir_nodes.py | 33 +++++++++++++++++++ .../unittest/test_tvmscript_roundtrip.py | 27 +++++++++++++++ 8 files changed, 122 insertions(+), 13 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 94ba853c493a..5cd860b8e929 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -521,6 +521,13 @@ class AllocateNode : public StmtNode { PrimExpr condition; /*! \brief The body to be executed. */ Stmt body; + /*! + * \brief Additional annotations about the allocation. + * + * These annotations can be used as auxiliary hint + * to future transformations. + */ + Map annotations; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); @@ -528,13 +535,14 @@ class AllocateNode : public StmtNode { v->Visit("extents", &extents); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("annotations", &annotations); v->Visit("span", &span); } bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && equal(extents, other->extents) && equal(condition, other->condition) && - equal(body, other->body); + equal(body, other->body) && equal(annotations, other->annotations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -543,6 +551,7 @@ class AllocateNode : public StmtNode { hash_reduce(extents); hash_reduce(condition); hash_reduce(body); + hash_reduce(annotations); } /*! @@ -570,7 +579,8 @@ class AllocateNode : public StmtNode { class Allocate : public Stmt { public: TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Span span = Span()); + Stmt body, Map annotations = Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); }; diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 1072809abf4b..487a71d4f077 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -104,14 +104,20 @@ def get_optional_vars(node, context): @register class Allocate(WithScopeHandler): - """With scope handler T.allocate(extents, dtype, scope, condition)""" + """With scope handler T.allocate(extents, dtype, scope, condition, annotations)""" def __init__(self): - def allocate(extents, dtype, scope, condition=True, span=None): + def allocate(extents, dtype, scope, condition=True, annotations=None, span=None): condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) return tvm.tir.Allocate( - self.buffer_var, dtype, extents, condition, self.body, span=span + self.buffer_var, + dtype, + extents, + condition, + self.body, + annotations=annotations, + span=span, ) super().__init__(allocate, concise_scope=True, def_symbol=True) @@ -137,7 +143,9 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): + def setup_buffer_var( + extents, dtype, scope, condition=True, annotations=None, span: Span = None + ): """Setup buffer var for a given type.""" buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index d57077f08b52..de200d5eabdd 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -318,13 +318,25 @@ class Allocate(Stmt): body : Stmt The body statement. + annotations: Optional[Mapping[str, Object]] + Additional annotation hints + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, buffer_var, dtype, extents, condition, body, span=None): + def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None): + if annotations is None: + annotations = dict() self.__init_handle_by_constructor__( - _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore + _ffi_api.Allocate, # type: ignore + buffer_var, + dtype, + extents, + condition, + body, + annotations, + span, ) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index f232994480f8..fa132f079793 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -449,8 +449,16 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); - doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << "), storage_scope = " << scope; + doc << "allocate(" << Print(op->buffer_var) << ", "; + doc << PrintDType(op->dtype) << ", "; + doc << Print(op->extents) << "), storage_scope = " << scope; + if (!op->annotations.empty()) { + std::vector attr_docs; + for (const auto& it : op->annotations) { + attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})"; + } if (!is_one(op->condition)) { doc << " if " << Print(op->condition); } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fdafdbfee0db..fa74e56f491c 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -769,6 +769,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { if (!is_one(op->condition)) { doc << ", " << Print(op->condition); } + if (!op->annotations.empty()) { + doc << ", annotations={"; + doc << PrintAnnotations(op->annotations); + doc << "}"; + } doc << ") as " << Print(op->buffer_var) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { @@ -777,6 +782,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { if (!is_one(op->condition)) { doc << ", " << Print(op->condition); } + if (!op->annotations.empty()) { + doc << ", annotations={"; + doc << PrintAnnotations(op->annotations); + doc << "}"; + } doc << ")" << Doc::NewLine() << PrintBody(op->body); } TryDeallocVar(op->buffer_var); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d59c94dc5753..0d42c20c2822 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Span span) { + Stmt body, Map annotations, Span span) { CHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" @@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); + node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } @@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array& extents) { TVM_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body, Span span) { - return Allocate(buffer_var, type, extents, condition, body, span); + Stmt body, Map annotations, Span span) { + return Allocate(buffer_var, type, extents, condition, body, annotations, span); }); TVM_REGISTER_NODE_TYPE(AllocateNode); diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index de94464187b0..fe719ee99693 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -473,5 +473,38 @@ def test_block_blockrealize(): assert output.find("with init()") != -1 +def test_tir_allocate(): + dtype = "int8" + storage_scope = "global" + ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + a = te.var("buffer", ptype) + allocate = tvm.tir.Allocate( + buffer_var=a, + dtype=dtype, + extents=[2, 2], + condition=tvm.get_global_func("tir.const_true")(dtype, None), + body=tvm.tir.Evaluate(2 + 1), + annotations={ + "attr1": "foo", + "attr2": "bar", + }, + ) + assert allocate.buffer_var == a + assert allocate.dtype == "int8" + assert list(allocate.extents) == [2, 2] + assert allocate.annotations["attr1"] == "foo" + assert allocate.annotations["attr2"] == "bar" + + # make sure we can print using TIRTextPrinter + func = tvm.tir.PrimFunc([], allocate) + output = func.astext() + assert ( + output.find( + 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' + ) + != -1 + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 94d4bed2a549..8058b96b024d 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3059,5 +3059,32 @@ def test_while_loop(): tvm.ir.assert_structural_equal(while_loop, rt_func) +# fmt: off +@T.prim_func +def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) +# fmt: on + + +def test_primfunc_with_allocate_annotations(): + func = primfunc_with_allocate_annotations + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))