From 46d152cdaf22d80d15baab62bb578489f38ec84a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 13 Jul 2022 16:38:28 -0700 Subject: [PATCH 1/8] [TIR] Add DeclBuffer node --- include/tvm/tir/stmt.h | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5dd4103e8202..5be1b9626d9c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -681,6 +681,40 @@ class AllocateConst : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); }; +/*! \brief Declare a buffer that can be used in the body */ +class DeclBufferNode : public StmtNode { + public: + /*! \brief The buffer being declared */ + Buffer buffer; + /*! \brief The body to be executed */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("body", &body); + v->Visit("span", &span); + } + + bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(body); + } + + static constexpr const char* _type_key = "tir.DeclBuffer"; + TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode); +}; + +/*! \brief Managed reference to DeclBufferNode */ +class DeclBuffer : public Stmt { + public: + TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode); +}; + /*! * \brief The container of seq statement. * Represent a sequence of statements. From 64e773891c7ff5aa043aad86dae0880a4edd6cf2 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 14 Jul 2022 14:53:46 -0700 Subject: [PATCH 2/8] [TIR] Add IR functors for DeclBuffer --- include/tvm/tir/stmt_functor.h | 4 ++++ python/tvm/tir/stmt.py | 20 ++++++++++++++++++++ src/target/source/codegen_c.cc | 5 +++++ src/target/source/codegen_c.h | 1 + src/tir/ir/stmt.cc | 23 +++++++++++++++++++++++ src/tir/ir/stmt_functor.cc | 16 ++++++++++++++++ tests/cpp/ir_functor_test.cc | 6 ++++-- 7 files changed, 73 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index fce2e1d67197..49b1f28e5d83 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -89,6 +89,7 @@ class StmtFunctor { virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -116,6 +117,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); + IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); @@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AllocateConstNode* op) override; + void VisitStmt_(const DeclBufferNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; @@ -260,6 +263,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const AllocateConstNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 063439e068a4..4847e377dec1 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -377,6 +377,26 @@ def __init__(self, buffer_var, dtype, extents, data_or_idx, body, annotations=No ) +@tvm._ffi.register_object("tir.DeclBuffer") +class DeclBuffer(Stmt): + """DeclBuffer node. + + Parameters + ---------- + buffer: Buffer + The buffer being declared. + + body: Stmt + The body statement to be executed. + + span: Optional[Span] + The location of this DeclBuffer in the source code. + """ + + def __init__(self, buffer, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span) + + @tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 3ad7882d792c..9fb6c63b2210 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -29,6 +29,7 @@ #include "../../arith/pattern_match.h" #include "codegen_params.h" +#include "tvm/tir/stmt.h" namespace tvm { namespace codegen { @@ -661,6 +662,10 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) { this->PrintStmt(op->body); } +void CodeGenC::VisitStmt_(const DeclBufferNode* op) { + this->PrintStmt(op->body); +} + void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 696ec62c5870..0af24dfdc066 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -166,6 +166,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const AllocateConstNode* op) override; + void VisitStmt_(const DeclBufferNode* op) override; /*! * \brief Print expr representing the thread tag diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 2b337520a249..524204f3d394 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -508,6 +508,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); +// DeclBuffer +DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->body = std::move(body); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { + return DeclBuffer(buffer, body, span); +}); + +TVM_REGISTER_NODE_TYPE(DeclBufferNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "decl_buffer " << op->buffer << "\n"; + p->stream << op->body; + }); + // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c0abf953eec2..f7db92bdc9d4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -63,6 +63,10 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { this->VisitStmt(op->body); } +void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } @@ -336,6 +340,18 @@ Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { } } +Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { + Stmt body = this->VisitStmt(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 33b145d3a41d..2909915c3288 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -181,6 +181,7 @@ TEST(IRF, StmtVisitor) { DataType dtype = DataType::Float(32); Var buf_var("b", PointerType(PrimType(dtype))); Buffer buffer = decl_buffer({16}); + body = DeclBuffer(buffer, std::move(body)); BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); @@ -309,6 +310,7 @@ TEST(IRF, StmtMutator) { DataType dtype = DataType::Float(32); Var buf_var("b", PointerType(PrimType(dtype))); Buffer buffer = decl_buffer({16}); + body = DeclBuffer(buffer, std::move(body)); BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); // construct block and block_realize @@ -318,8 +320,8 @@ TEST(IRF, StmtMutator) { body = v(std::move(block_realize)); // the body should be changed Block new_block = body.as()->block; - ICHECK(new_block->body.as()->extents[1].same_as(x)); - ICHECK(new_block->init.as()->extents[1].same_as(x)); + ICHECK(new_block->body.as()->body.as()->extents[1].same_as(x)); + ICHECK(new_block->init.as()->body.as()->extents[1].same_as(x)); ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); From aa1626ff2b6577554f777fa324eaafb37d6c0c16 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 2 Aug 2022 13:41:59 -0700 Subject: [PATCH 3/8] [TVMScript] Add printer and parser for DeclBuffer --- python/tvm/script/tir/__init__.pyi | 12 ++++ python/tvm/script/tir/special_stmt.py | 57 +++++++++++++++++++ src/printer/text_printer.h | 1 + src/printer/tir_text_printer.cc | 12 ++++ src/printer/tvmscript_printer.cc | 21 +++++++ .../unittest/test_tvmscript_roundtrip.py | 12 ++++ 6 files changed, 115 insertions(+) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index f03c5c06da3d..a62fb102bec5 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -187,6 +187,18 @@ def match_buffer( buffer_type: str = "default", axis_separators: Optional[List[int]] = None, ) -> Buffer: ... +def decl_buffer( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: Optional[List[int]] = None, +) -> Buffer: ... def buffer_decl( shape: Sequence[Union[PrimExpr, int]], dtype: str = "float32", diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 15502055b7fc..f0ce33d9d5d1 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -239,6 +239,63 @@ def buffer_decl( super().__init__(buffer_decl, def_symbol=True) +@register +class DeclBuffer(SpecialStmt): + """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type, axis_separators) + Example + ------- + .. code-block:: python + A = T.decl_buffer((128, 128), dtype="float32") + """ + + def __init__(self): + def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`decl_buffer` must be assigned to a single buffer, e.g. A = decl_buffer(...)", + self.node.span, + ) + + if strides is None: + strides = [] + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span + ) + buffer_name: str = self.node.lhs[0].id.name + buffer = tvm.tir.decl_buffer( + shape, + dtype, + buffer_name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + span=span, + ) + self.context.update_symbol(buffer_name, buffer, self.node) + return buffer + + super().__init__(decl_buffer, def_symbol=True) + + @register class AllocBuffer(SpecialStmt): """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align, diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 05a00e3305e1..2dc0997f82ec 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -353,6 +353,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const AllocateConstNode* op) override; + Doc VisitStmt_(const DeclBufferNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fe829016b6b5..afc91863d66d 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -557,6 +557,18 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { + Doc doc; + doc << AllocBuf(op->buffer) << " = decl_buffer(" << PrintDType(op->buffer->dtype) + << Print(op->buffer->shape) << ")" << Doc::NewLine(); + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index aaebc7409f29..c3e9dc2ecb94 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -245,6 +245,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const AllocateConstNode* op) override; + Doc VisitStmt_(const DeclBufferNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const ForNode* op) override; @@ -1161,6 +1162,26 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) { + const Buffer& buffer = op->buffer; + auto storage_scope = GetPtrStorageScope(buffer->data); + Doc func_call; + func_call << tir_prefix_ << ".decl_buffer(" << Print(buffer->shape) << ", " << PrintDType(buffer->dtype) + << ", " << Print(storage_scope); + func_call << ")"; + + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with " << func_call << " as " << Print(buffer) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << Print(buffer) << " = " << func_call << Doc::NewLine(); + doc << PrintBody(op->body); + } + TryDeallocVar(buffer->data); + return doc; +} + Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << ":"; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 528357339c72..3591f1250864 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3303,6 +3303,17 @@ def func(out_ret_value: T.Ptr[T.void]): return func +def decl_buffer(): + @T.prim_func + def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + A_flattened = T.decl_buffer(data=A.data, shape=(256,), dtype="float32") + B_flattened = T.decl_buffer(data=B.data, shape=(256,), dtype="float32") + for i in range(256): + B_flattened[i] = A_flattened[i] + T.float32(1.0) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3342,6 +3353,7 @@ def func(out_ret_value: T.Ptr[T.void]): buffer_ramp_access_as_slice_index, let_expression, void_ptr, + decl_buffer, ) From 60dd61f1e253123bae11ff70b19a4e284a8e0009 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Aug 2022 14:15:18 -0700 Subject: [PATCH 4/8] Update printer --- src/printer/tir_text_printer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index afc91863d66d..c5902cadda51 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -559,8 +559,8 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { Doc doc; - doc << AllocBuf(op->buffer) << " = decl_buffer(" << PrintDType(op->buffer->dtype) - << Print(op->buffer->shape) << ")" << Doc::NewLine(); + doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->var) << ", " + << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine(); if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { From 2794f06c05aca410f00b02c223165ed48e5ec887 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Aug 2022 15:06:04 -0700 Subject: [PATCH 5/8] Update printer --- python/tvm/script/tir/scope_handler.py | 81 ++++++++++++++++++++++++++ python/tvm/script/tir/special_stmt.py | 57 ------------------ python/tvm/tir/__init__.py | 1 + src/printer/tir_text_printer.cc | 2 +- src/printer/tvmscript_printer.cc | 12 ++-- 5 files changed, 88 insertions(+), 65 deletions(-) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 92aaf8b4d992..da7545c9a9e9 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -224,6 +224,87 @@ def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None context.update_symbol(name, self.buffer, node) +@register +class DeclBuffer(WithScopeHandler): + """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type, axis_separators) + Example + ------- + .. code-block:: python + A = T.decl_buffer((128, 128), dtype="float32") + """ + + def __init__(self): + def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, + ): + return tvm.tir.DeclBuffer(self.buffer, self.body, span=span) + + super().__init__(decl_buffer, concise_scope=True, def_symbol=True) + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define buffer vars in symbol table + if isinstance(node, synr.ast.With): + vars = WithScopeHandler.get_optional_vars(node, context) + if len(vars) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) + name = vars[0].id.name + var_span = vars[0].id.span + elif isinstance(node, synr.ast.Assign): + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span + else: + raise Exception("Internal Bug") + + def setup_buffer( + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + span: Span = None, + ): + self.buffer = tvm.tir.decl_buffer( + shape=shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + data_alignment=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + span=span, + ) + + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) + + @register class LaunchThread(WithScopeHandler): """With scope handler T.launch_thread(env_var, extent)""" diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index f0ce33d9d5d1..15502055b7fc 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -239,63 +239,6 @@ def buffer_decl( super().__init__(buffer_decl, def_symbol=True) -@register -class DeclBuffer(SpecialStmt): - """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type, axis_separators) - Example - ------- - .. code-block:: python - A = T.decl_buffer((128, 128), dtype="float32") - """ - - def __init__(self): - def decl_buffer( - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`decl_buffer` must be assigned to a single buffer, e.g. A = decl_buffer(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - buffer_name: str = self.node.lhs[0].id.name - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - self.context.update_symbol(buffer_name, buffer, self.node) - return buffer - - super().__init__(decl_buffer, def_symbol=True) - - @register class AllocBuffer(SpecialStmt): """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align, diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a3798ccab44e..c64b7dfe713d 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -36,6 +36,7 @@ Allocate, AllocateConst, AttrStmt, + DeclBuffer, ) from .stmt import ProducerRealize, SeqStmt diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index c5902cadda51..894a9cec1e2a 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -559,7 +559,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { Doc doc; - doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->var) << ", " + doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", " << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine(); if (op->body->IsInstance()) { doc << PrintBody(op->body); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index c3e9dc2ecb94..0926858ae0db 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1164,21 +1164,19 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) { const Buffer& buffer = op->buffer; - auto storage_scope = GetPtrStorageScope(buffer->data); + buf_not_in_headers_.insert(buffer.get()); + Doc buffer_name = Print(op->buffer); Doc func_call; - func_call << tir_prefix_ << ".decl_buffer(" << Print(buffer->shape) << ", " << PrintDType(buffer->dtype) - << ", " << Print(storage_scope); - func_call << ")"; + func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")"; Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(buffer) << ":"; + doc << "with " << func_call << " as " << buffer_name << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { - doc << Print(buffer) << " = " << func_call << Doc::NewLine(); + doc << buffer_name << " = " << func_call << Doc::NewLine(); doc << PrintBody(op->body); } - TryDeallocVar(buffer->data); return doc; } From 53fd4104319169e763358201f243105c84e2475f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Aug 2022 15:07:33 -0700 Subject: [PATCH 6/8] Add test case --- tests/python/unittest/test_tvmscript_roundtrip.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 3591f1250864..0a2cec6011ef 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3308,8 +3308,9 @@ def decl_buffer(): def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: A_flattened = T.decl_buffer(data=A.data, shape=(256,), dtype="float32") B_flattened = T.decl_buffer(data=B.data, shape=(256,), dtype="float32") + C_alias = T.decl_buffer(data=A_flattened.data, shape=(256,), dtype="float32") for i in range(256): - B_flattened[i] = A_flattened[i] + T.float32(1.0) + B_flattened[i] = A_flattened[i] + C_alias[i] + T.float32(1.0) return func From 70da72a2078cf581402104b009be640273814bb3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Aug 2022 15:27:15 -0700 Subject: [PATCH 7/8] lint --- src/target/source/codegen_c.cc | 4 +--- src/tir/ir/stmt_functor.cc | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 9fb6c63b2210..0646a3a5d322 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -662,9 +662,7 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) { this->PrintStmt(op->body); } -void CodeGenC::VisitStmt_(const DeclBufferNode* op) { - this->PrintStmt(op->body); -} +void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); } void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f7db92bdc9d4..c75eb52f9296 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -63,9 +63,7 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { - this->VisitStmt(op->body); -} +void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; From 4c5dbce5d740d151e743a752dc207f9d69e3edd6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 5 Aug 2022 11:01:08 -0700 Subject: [PATCH 8/8] fix --- src/target/source/codegen_c.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 0646a3a5d322..3fe7fa50d3cf 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -29,7 +29,6 @@ #include "../../arith/pattern_match.h" #include "codegen_params.h" -#include "tvm/tir/stmt.h" namespace tvm { namespace codegen {