diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5dd4103e8202..5be1b9626d9c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -681,6 +681,40 @@ class AllocateConst : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); }; +/*! \brief Declare a buffer that can be used in the body */ +class DeclBufferNode : public StmtNode { + public: + /*! \brief The buffer being declared */ + Buffer buffer; + /*! \brief The body to be executed */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("body", &body); + v->Visit("span", &span); + } + + bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(body); + } + + static constexpr const char* _type_key = "tir.DeclBuffer"; + TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode); +}; + +/*! \brief Managed reference to DeclBufferNode */ +class DeclBuffer : public Stmt { + public: + TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode); +}; + /*! * \brief The container of seq statement. * Represent a sequence of statements. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index fce2e1d67197..49b1f28e5d83 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -89,6 +89,7 @@ class StmtFunctor { virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -116,6 +117,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); + IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); @@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AllocateConstNode* op) override; + void VisitStmt_(const DeclBufferNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; @@ -260,6 +263,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const AllocateConstNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index f03c5c06da3d..a62fb102bec5 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -187,6 +187,18 @@ def match_buffer( buffer_type: str = "default", axis_separators: Optional[List[int]] = None, ) -> Buffer: ... +def decl_buffer( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: Optional[List[int]] = None, +) -> Buffer: ... def buffer_decl( shape: Sequence[Union[PrimExpr, int]], dtype: str = "float32", diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 92aaf8b4d992..da7545c9a9e9 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -224,6 +224,87 @@ def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None context.update_symbol(name, self.buffer, node) +@register +class DeclBuffer(WithScopeHandler): + """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type, axis_separators) + Example + ------- + .. code-block:: python + A = T.decl_buffer((128, 128), dtype="float32") + """ + + def __init__(self): + def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, + ): + return tvm.tir.DeclBuffer(self.buffer, self.body, span=span) + + super().__init__(decl_buffer, concise_scope=True, def_symbol=True) + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define buffer vars in symbol table + if isinstance(node, synr.ast.With): + vars = WithScopeHandler.get_optional_vars(node, context) + if len(vars) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) + name = vars[0].id.name + var_span = vars[0].id.span + elif isinstance(node, synr.ast.Assign): + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span + else: + raise Exception("Internal Bug") + + def setup_buffer( + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + span: Span = None, + ): + self.buffer = tvm.tir.decl_buffer( + shape=shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + data_alignment=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + span=span, + ) + + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) + + @register class LaunchThread(WithScopeHandler): """With scope handler T.launch_thread(env_var, extent)""" diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a3798ccab44e..c64b7dfe713d 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -36,6 +36,7 @@ Allocate, AllocateConst, AttrStmt, + DeclBuffer, ) from .stmt import ProducerRealize, SeqStmt diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 063439e068a4..4847e377dec1 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -377,6 +377,26 @@ def __init__(self, buffer_var, dtype, extents, data_or_idx, body, annotations=No ) +@tvm._ffi.register_object("tir.DeclBuffer") +class DeclBuffer(Stmt): + """DeclBuffer node. + + Parameters + ---------- + buffer: Buffer + The buffer being declared. + + body: Stmt + The body statement to be executed. + + span: Optional[Span] + The location of this DeclBuffer in the source code. + """ + + def __init__(self, buffer, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span) + + @tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 05a00e3305e1..2dc0997f82ec 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -353,6 +353,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const AllocateConstNode* op) override; + Doc VisitStmt_(const DeclBufferNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fe829016b6b5..894a9cec1e2a 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -557,6 +557,18 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { + Doc doc; + doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", " + << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine(); + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index aaebc7409f29..0926858ae0db 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -245,6 +245,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const AllocateConstNode* op) override; + Doc VisitStmt_(const DeclBufferNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const ForNode* op) override; @@ -1161,6 +1162,24 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) { + const Buffer& buffer = op->buffer; + buf_not_in_headers_.insert(buffer.get()); + Doc buffer_name = Print(op->buffer); + Doc func_call; + func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")"; + + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with " << func_call << " as " << buffer_name << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << buffer_name << " = " << func_call << Doc::NewLine(); + doc << PrintBody(op->body); + } + return doc; +} + Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << ":"; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 3ad7882d792c..3fe7fa50d3cf 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -661,6 +661,8 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) { this->PrintStmt(op->body); } +void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); } + void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 696ec62c5870..0af24dfdc066 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -166,6 +166,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const AllocateConstNode* op) override; + void VisitStmt_(const DeclBufferNode* op) override; /*! * \brief Print expr representing the thread tag diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 2b337520a249..524204f3d394 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -508,6 +508,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); +// DeclBuffer +DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->body = std::move(body); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { + return DeclBuffer(buffer, body, span); +}); + +TVM_REGISTER_NODE_TYPE(DeclBufferNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "decl_buffer " << op->buffer << "\n"; + p->stream << op->body; + }); + // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c0abf953eec2..c75eb52f9296 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -63,6 +63,8 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { this->VisitStmt(op->body); } +void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); } + void StmtVisitor::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } @@ -336,6 +338,18 @@ Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { } } +Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { + Stmt body = this->VisitStmt(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 33b145d3a41d..2909915c3288 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -181,6 +181,7 @@ TEST(IRF, StmtVisitor) { DataType dtype = DataType::Float(32); Var buf_var("b", PointerType(PrimType(dtype))); Buffer buffer = decl_buffer({16}); + body = DeclBuffer(buffer, std::move(body)); BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); @@ -309,6 +310,7 @@ TEST(IRF, StmtMutator) { DataType dtype = DataType::Float(32); Var buf_var("b", PointerType(PrimType(dtype))); Buffer buffer = decl_buffer({16}); + body = DeclBuffer(buffer, std::move(body)); BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); // construct block and block_realize @@ -318,8 +320,8 @@ TEST(IRF, StmtMutator) { body = v(std::move(block_realize)); // the body should be changed Block new_block = body.as()->block; - ICHECK(new_block->body.as()->extents[1].same_as(x)); - ICHECK(new_block->init.as()->extents[1].same_as(x)); + ICHECK(new_block->body.as()->body.as()->extents[1].same_as(x)); + ICHECK(new_block->init.as()->body.as()->extents[1].same_as(x)); ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 528357339c72..0a2cec6011ef 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3303,6 +3303,18 @@ def func(out_ret_value: T.Ptr[T.void]): return func +def decl_buffer(): + @T.prim_func + def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + A_flattened = T.decl_buffer(data=A.data, shape=(256,), dtype="float32") + B_flattened = T.decl_buffer(data=B.data, shape=(256,), dtype="float32") + C_alias = T.decl_buffer(data=A_flattened.data, shape=(256,), dtype="float32") + for i in range(256): + B_flattened[i] = A_flattened[i] + C_alias[i] + T.float32(1.0) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3342,6 +3354,7 @@ def func(out_ret_value: T.Ptr[T.void]): buffer_ramp_access_as_slice_index, let_expression, void_ptr, + decl_buffer, )