diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5ee847e2f010..ced060b8cc86 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -423,6 +423,12 @@ TVM_DLL Pass CompactBufferAllocation(); */ TVM_DLL Pass LegalizePackedCalls(); +/*! + * \brief Remove match buffers inside the block. Also, it will validate the binding. + * \return The pass. + */ +TVM_DLL Pass LowerMatchBuffer(); + /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore * to single dimensional Load/Store. Also remove Block to diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index ae3e9d885f1a..44c92b792f12 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -57,7 +57,7 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # match_buffers of the block, # which bind a sub-region of source buffer into a new buffer - D = tir.match_buffer_region(C[vi, vj]) + D = tir.match_buffer(C[vi, vj], ()) # init part of the block, executed when all reduce axes are the beginning value with tir.init(): @@ -65,13 +65,13 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # block body CC[0, 0] = A[vi, vk] * B[vj, vk] - D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] + D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] """ alloc_buffers: List[Buffer] = [] """List[Buffer]: list of tir.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature""" + """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature""" iter_bindings: Mapping[Var, PrimExpr] = {} """Mapping[Var, PrimExpr]: map of block iter var to its values""" reads: Optional[List[BufferSlice]] = None diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 49f71041590b..9acf21b6ba3a 
100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -784,10 +784,11 @@ def transform_Slice(self, node): def transform_Subscript(self, node): """Array access visitor. - By now only 2 types of Subscript are supported: + By now only 3 types of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) + 3. Array[index], Buffer element access """ symbol = self.transform(node.params[0]) @@ -812,6 +813,25 @@ def transform_Subscript(self, node): return BufferSlice( symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) ) + elif isinstance(symbol, tvm.container.Array): + if len(indexes) > 1: + self.report_error( + "Array access should be one-dimension access, but the indices are " + + str(indexes), + node.span, + ) + index = indexes[0] + if not isinstance(index, (int, tvm.tir.expr.IntImm)): + self.report_error( + "Array access index expected int or IntImm, but got " + str(type(index)), + node.span, + ) + if int(index) >= len(symbol): + self.report_error( + f"Array access out of bound, size: {len(symbol)}, got index {index}.", + node.span, + ) + return symbol[int(index)] else: self.report_error( f"Cannot subscript from a {type(symbol).__name__}. Only variables and " diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index b8cd887ea362..25af7635742b 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -96,12 +96,24 @@ def handle( @register class MatchBuffer(SpecialStmt): - """Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, + """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type) + + Note + ---- + This Special Stmt will perform different behavior depending on the type of param. 
+ If the param is a var in function parameter, it will create a buffer from DLTensor. + Else if the param is a subregion of other buffers, then create a subregion match inside a block. + Example ------- + Match buffer from function parameter .. code-block:: python A = tir.match_buffer(a, (128, 128), dtype="float32") + + Match buffer from Buffer subregion + .. code-block:: python + A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") """ def __init__(self): @@ -123,10 +135,6 @@ def match_buffer( "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", self.node.span, ) - if param not in self.context.func_params: - self.context.report_error( - "Can not bind non-input param to buffer", self.node.rhs.params[0].span - ) if strides is None: strides = [] align = convert_to_int(align, "align", self.context.report_error, self.node.span) @@ -146,7 +154,23 @@ def match_buffer( buffer_type, span=span, ) - self.context.func_buffer_map[param] = buffer + if isinstance(param, tvm.tir.Var): + if param not in self.context.func_params: + self.context.report_error( + "Can not bind non-input param to buffer", self.node.rhs.params[0].span + ) + self.context.func_buffer_map[param] = buffer + elif isinstance(param, BufferSlice): + buffer_region = buffer_slice_to_region(param) + self.context.current_block_scope().match_buffers.append( + tvm.tir.MatchBufferRegion(buffer, buffer_region) + ) + else: + self.context.report_error( + "The source of match_buffer expected Var or BufferSlice, but got " + + str(type(param)), + self.node.rhs.params[0].span, + ) self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -414,68 +438,6 @@ def where(predicate, span=None): super().__init__(where, def_symbol=False) -@register -class BlockMatchBufferRegion(SpecialStmt): - """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor) - - Example - ------- - .. 
code-block:: python - - B = tir.match_buffer_region(A[0: 4]) - """ - - def __init__(self): - def match_buffer_region( - source, - strides=None, - elem_offset=None, - align=-1, - offset_factor=0, - span=None, - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - if not isinstance(self.node, ast.Assign): - self.context.report_error( - "match_buffer_region must be assigned to a buffer, " - + "e.g. A = match_buffer_region(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - - if not isinstance(source, BufferSlice): - self.context.report_error( - "match_buffer_region needs a buffer region as source", - span=span, - ) - buffer_region = buffer_slice_to_region(source) - shape = [r.extent for r in buffer_region.region] - buffer = tvm.tir.decl_buffer( - shape, - buffer_region.buffer.dtype, - self.node.lhs.id.name, - data=None, - strides=strides, - elem_offset=elem_offset, - scope=buffer_region.buffer.scope(), - data_alignment=align, - offset_factor=offset_factor, - span=span, - ) - self.context.current_block_scope().match_buffers.append( - tvm.tir.MatchBufferRegion(buffer, buffer_region) - ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) - - super().__init__(match_buffer_region, def_symbol=True) - - @register class VarDef(SpecialStmt): """Special function for defining a Var""" diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index b445bcb25005..6dddd7b119a0 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -255,7 +255,7 @@ def decl_buffer( dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + shape_dtype = shape[0].dtype if shape 
and hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1e5c303fa17c..3e47eb5a4254 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -644,6 +644,17 @@ def CompactBufferAllocation(): return _ffi_api.CompactBufferAllocation() # type: ignore +def LowerMatchBuffer(): + """Remove match buffers inside the block. Also, it will validate the binding. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerMatchBuffer() # type: ignore + + def FlattenBuffer(): """Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block to diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1591e875a4b3..9795bf3bc704 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -222,6 +222,7 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::FlattenBuffer()); } pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index de03877afe3a..f232994480f8 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -583,8 +583,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { - body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source) - << ")" << 
Doc::NewLine(); + body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" + << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 6616731e6578..cc7536b48cfd 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -337,29 +337,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { const Buffer& buf = op->buffer; buf_not_in_headers.insert(buf.get()); - Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source); - if (!buf->strides.empty()) { - doc << ", strides=" << Print(buf->strides); - } - if (buf->offset_factor != 0 && buf->elem_offset->IsInstance()) { - Var elem_offset = Downcast(buf->elem_offset); - if (memo_var_.find(elem_offset) != memo_var_.end()) { - doc << ", elem_offset=" << Print(buf->elem_offset); - } else { - // implicitly define elem_offset - memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset"); - var_not_in_headers.insert(elem_offset.get()); - } - } else { - doc << ", elem_offset=" << Print(buf->elem_offset); - } - if (buf->data_alignment != -1) { - doc << ", align=" << buf->data_alignment; - } - if (buf->offset_factor != 0) { - doc << ", offset_factor=" << buf->offset_factor; - } - doc << ")"; + Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", " + << memo_buf_decl_[op->buffer] << ")"; return doc; } diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index b1da536f1dad..8f87ef920784 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -26,6 +26,7 @@ #include #include +#include "../transforms/ir_utils.h" namespace tvm { namespace tir { @@ -65,8 +66,12 @@ class BlockReadWriteDetector : public StmtExprVisitor { std::vector> read_regions_; /*! 
\brief The write regions of the current block */ std::vector> write_regions_; + /*! \brief The opaque regions of the current block */ + std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ Map buffer_var_map_; + /*! \brief The target buffer var mapping to its matching */ + std::unordered_map match_buffers_; /*! \brief The analyzer for simplifying*/ arith::Analyzer analyzer_; @@ -78,14 +83,18 @@ class BlockReadWriteDetector : public StmtExprVisitor { * \param region The provided region */ void Update(std::vector* buffers, std::vector>* regions, - const Buffer& buffer, const std::vector& region); + Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. */ Array CollectRegions(const std::vector& buffers, const std::vector>& regions); - /*! \brief Helper function to add a opaque buffer. */ - void AddOpaque(const Var& buffer_var); + /*! \brief Helper function to convert matched access region to source region. */ + std::vector ConvertMatchedRegion(const MatchBufferRegion& match_buffer, + const std::vector& int_sets) const; + + /*! \brief Helper function to update a opaque access. 
*/ + void UpdateOpaque(const Var& buffer_var); void VisitStmt_(const ForNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -97,8 +106,13 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - ICHECK(stmt.as() != nullptr) - << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + const auto* block = stmt.as(); + ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + for (const MatchBufferRegion& match_buffer : block->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } StmtExprVisitor::operator()(stmt); } @@ -111,18 +125,13 @@ Array BlockReadWriteDetector::CollectWrites() { } Array BlockReadWriteDetector::CollectOpaques() { - Array res; - res.reserve(opaque_buffers_.size()); - for (const Buffer& buffer : opaque_buffers_) { - res.push_back(BufferRegion::FullRegion(buffer)); - } - return res; + return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); ExprVisitor::VisitExpr_(op); } @@ -143,7 +152,7 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); StmtVisitor::VisitStmt_(op); } @@ -184,11 +193,39 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } } +std::vector BlockReadWriteDetector::ConvertMatchedRegion( + const MatchBufferRegion& match_buffer, const std::vector& int_sets) const { + const Buffer& buffer = match_buffer->buffer; + + 
Region region; + region.reserve(int_sets.size()); + ICHECK_EQ(buffer->shape.size(), int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + const tvm::arith::IntSet& int_set = int_sets[i]; + region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + + region = ConvertRegion(match_buffer, region); + + std::vector result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::EvalSet(range, dom_map_)); + } + return result; +} + void BlockReadWriteDetector::Update(std::vector* buffers, - std::vector>* regions, - const Buffer& buffer, - const std::vector& region) { + std::vector>* regions, Buffer buffer, + std::vector region) { if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; + // Handle match_buffer remap + auto it = match_buffers_.find(buffer->data.get()); + if (it != match_buffers_.end()) { + const MatchBufferRegion& match_buffer = it->second; + buffer = match_buffer->source->buffer; + region = ConvertMatchedRegion(match_buffer, std::move(region)); + } ICHECK_EQ(buffers->size(), regions->size()) << " Expected the buffer and regions to have the same size "; for (size_t i = 0; i < regions->size(); ++i) { @@ -200,8 +237,8 @@ void BlockReadWriteDetector::Update(std::vector* buffers, return; } } - buffers->push_back(buffer); - regions->push_back(region); + buffers->push_back(std::move(buffer)); + regions->push_back(std::move(region)); } Array BlockReadWriteDetector::CollectRegions( @@ -213,8 +250,9 @@ Array BlockReadWriteDetector::CollectRegions( for (size_t i = 0; i < regions.size(); ++i) { Array region; region.reserve(regions[i].size()); + ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { - tvm::arith::IntSet range = regions[i][j]; + const tvm::arith::IntSet& range = regions[i][j]; region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } res.push_back(BufferRegion(buffers[i], 
region)); @@ -222,14 +260,18 @@ Array BlockReadWriteDetector::CollectRegions( return res; } -void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) { +void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { auto it = buffer_var_map_.find(buffer_var); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; - for (const Buffer& opaque_buffer : opaque_buffers_) { - if (buffer.same_as(opaque_buffer)) return; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, dom_map_)); } - opaque_buffers_.push_back(buffer); + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); } } diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 6f2622f3a61e..e680d689735d 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -85,9 +85,17 @@ class LCADetector : public StmtExprVisitor { for (const Buffer& buf : op->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); auto* current_scope = arena_.make(parent_scope, op, n); + ancestor_scopes_.push_back(current_scope); + // Update match_buffers + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + UpdateBufferLCA(match_buffer->source->buffer.get()); + match_buffers_.insert(match_buffer->buffer.get()); + } + StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); } @@ -129,8 +137,11 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { - const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + if (match_buffers_.find(buffer) == match_buffers_.end()) { + // Ingore buffer created by block 
match_buffer + const ScopeInfo*& lca = buffer_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + } } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { @@ -164,6 +175,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The match buffers inside blocks. */ + std::unordered_set match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c15b3bb47bf4..f265a8ae2b1b 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,15 +36,12 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map, bool contain_root) - : buffer_var_map_(buffer_var_map), contain_root_(contain_root) {} + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Whether the stmt contains at least one block. */ bool contains_block = false; private: Map* buffer_var_map_; - bool contain_root_; - bool visited_root_ = false; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; Stmt body = StmtMutator::VisitStmt_(op); @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator { } Stmt VisitStmt_(const BlockNode* op) override { - bool is_root_block = contain_root_ && !visited_root_; - visited_root_ = true; // Buffers allocated in the block can be accessed by its body. 
for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->Set(target_buffer->data, target_buffer); + } Block block = Downcast(StmtMutator::VisitStmt_(op)); // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->erase(target_buffer->data); + } // Get access detection mask // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write int mask = 0; @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - if (op->iter_vars.empty()) { - // non-root opaque block is not allowed - CHECK(is_root_block) - << "ValueError: Can not auto detect buffer access region for an opaque block. Please " - "annotate the access region manually."; - return std::move(block); - } auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); const Array& reads = access_region[0]; const Array& writes = access_region[1]; @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } bool contain_root = root_allocates.empty() && func->body->IsInstance() && Downcast(func->body)->block->iter_vars.empty(); - ScriptCompleter script_completer(&buffer_var_map, contain_root); + ScriptCompleter script_completer(&buffer_var_map); // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index cb06df8b7655..d59c94dc5753 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -20,6 +20,7 @@ /*! 
* \file tvm/tir/stmt.cc */ +#include #include #include #include @@ -658,6 +659,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // BufferRegion BufferRegion::BufferRegion(Buffer buffer, Array region) { + CHECK_EQ(buffer->shape.size(), region.size()) + << "The dimension between " << buffer << " and region " << region + << " mismatched, the buffer is " << buffer; ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->region = std::move(region); @@ -705,6 +709,49 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // MatchBufferRegion MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { + const Buffer& source_buffer = source->buffer; + arith::Analyzer analyzer; + // Check scope and dtype + CHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " + << source_buffer.scope(); + CHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " + << source_buffer->dtype; + + // Check data_alignment + CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) + << "Trying to match buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + + // Check BufferType. AutoBroadcast is not allowed for now. + CHECK(buffer->buffer_type == BufferType::kDefault && + source_buffer->buffer_type == BufferType::kDefault) + << "AutoBroadcast is not allowed in MatchBuffer"; + + // Validate shape + CHECK(source->region.size() >= buffer->shape.size()) + << "Dimension of source Region expected to be larger or equal than target buffer shape, but " + "got " + << source->region.size() << " vs. 
" << buffer->shape.size(); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < offset; ++i) { + CHECK(analyzer.CanProve(source->region[i]->extent == 1)) + << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; + } + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const PrimExpr& buffer_shape = buffer->shape[i]; + if (!buffer_shape->IsInstance()) { + CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) + << "The dimension mismatched between source region and target buffer shape, got " + << source_range->extent << " vs. " << buffer_shape << "."; + } + } + // Note that we do not check elem_offset and strides in this function + + // Construction ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->source = std::move(source); @@ -721,7 +768,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->buffer->name << " = match_buffer_region("; + p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index f69a9e54afa4..bd1fa9bce836 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -362,6 +362,7 @@ class BufferCompactor : public StmtExprMutator { BlockNode* n = block.CopyOnWrite(); RewriteBufferRegions(&n->reads); RewriteBufferRegions(&n->writes); + RewriteMatchBuffers(&n->match_buffers); n->alloc_buffers = std::move(alloc_buffers); return std::move(block); } @@ -434,6 +435,18 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } + void RewriteMatchBuffers(Array* match_buffers) const { + Array result; + result.reserve(match_buffers->size()); + for (const auto& 
match_buffer : *match_buffers) { + const BufferRegion& buffer_region = match_buffer->source; + auto p = make_object(*buffer_region.get()); + RewriteBufferRegion(&p->buffer, &p->region); + result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); + } + *match_buffers = std::move(result); + } + /*! \brief The allocation information about each buffer. */ std::unordered_map buffer_info_; }; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index b7348fe09fe2..7248bd4e663f 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -23,6 +23,7 @@ */ #include "ir_utils.h" +#include #include #include @@ -197,5 +198,51 @@ String GetPtrStorageScope(Var buffer_var) { return ptr_type->storage_scope; } +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(indices.size(), target->shape.size()); + + arith::Analyzer analyzer; + Array result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - indices.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& range = source->region[i]; + ICHECK(analyzer.CanProve(range->extent == 1)); + result.push_back(range->min); + } + for (size_t i = 0; i < indices.size(); ++i) { + const Range& range = source->region[i + offset]; + const PrimExpr& index = indices[i]; + result.push_back(range->min + index); + } + return result; +} + +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(region.size(), target->shape.size()); + + arith::Analyzer analyzer; + Region result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - region.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& source_range = source->region[i]; + 
ICHECK(analyzer.CanProve(source_range->extent == 1)); + result.push_back(Range::FromMinExtent(source_range->min, 1)); + } + for (size_t i = 0; i < region.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const Range& target_range = region[i]; + result.push_back( + Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); + } + return result; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index b5a154b707af..79c5f0609243 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -197,6 +197,22 @@ Stmt ConvertSSA(Stmt stmt); * \return A string representing the storage scope of this buffer variable. */ String GetPtrStorageScope(Var buffer_var); + +/*! + * \brief Convert match buffer target buffer access indices to original one. + * \param indices The indices of the target buffer + * \return The indices of source buffer. + */ +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices); + +/*! + * \brief Convert match buffer target buffer region to original one. + * \param region The sub-region of the target buffer + * \return The region of source buffer. + */ +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc new file mode 100644 index 000000000000..2f8fbe0ea6e7 --- /dev/null +++ b/src/tir/transforms/lower_match_buffer.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_match_buffer.cc + * \brief The pass for lowering match_buffer. + */ + +#include +#include +#include +#include +#include + +#include "../ir/functor_common.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { +class MatchBufferLower : public StmtExprMutator { + public: + explicit MatchBufferLower(const PrimFunc& func) { + for (const Var& param : func->params) { + // Mark input var as const variable. 
+ if (!param.dtype().is_handle()) var_map_.Set(param, param); + } + } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + CheckAndUpdateVarMap(match_buffer); + } + + Stmt stmt = StmtExprMutator ::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Array reads = MutateArray( + op->reads, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + + if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { + return stmt; + } else { + auto n = CopyOnWrite(op); + n->match_buffers = {}; + n->reads = std::move(reads); + n->writes = std::move(writes); + return Stmt(n); + } + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + return (*it).second; + } else { + return std::move(v); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return stmt; + } else { + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + + auto n = CopyOnWrite(op); + n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + n->buffer = source->buffer; + return Stmt(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return expr; + } else { + const Buffer& buffer = 
(*it).first; + const BufferRegion& source = (*it).second; + Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + return BufferLoad(source->buffer, indices); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Load from buffer created by match_buffer is not allowed, but got: " << expr; + return expr; + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Store from buffer created by match_buffer is not allowed, but got: " << stmt; + return stmt; + } + + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + auto it = match_buffers_.find(buffer); + if (it == match_buffers_.end()) { + return buffer_region; + } else { + const BufferRegion& source = (*it).second; + Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); + return BufferRegion(source->buffer, std::move(region)); + } + } + + private: + void CheckAndUpdateVarMap(const MatchBufferRegion& match_buffer) { + // Step.1. Check + const Buffer& buffer = match_buffer->buffer; + const BufferRegion& source = VisitBufferRegion(match_buffer->source); + const Buffer& source_buffer = source->buffer; + + // Step.1.1. Check scope & dtype + ICHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs." + << source_buffer.scope(); + ICHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs." + << source_buffer->dtype; + + // Step.1.2. 
Check data alignment + if (source_buffer->data_alignment % buffer->data_alignment != 0) { + LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + } + if (is_zero(buffer->elem_offset)) { + ICHECK(is_zero(source_buffer->elem_offset)) + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << buffer->elem_offset + << ", provided elem_offset=" << source_buffer->elem_offset; + } + + // Step.2. Update + match_buffers_.Set(buffer, source); + // Step.2.1. Update buffer data + Bind(buffer->data, source_buffer->data, buffer->name + ".data"); + + // Step.2.2. Update element offset + // Note we create Load via vload and try to reuse index calculate. + { + Array indices; + indices.reserve(source->region.size()); + for (const Range& range : source->region) { + indices.push_back(range->min); + } + + Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); + Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + << "The source elem_offset " << buffer->elem_offset + << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + } + + // Step 2.3. Check and update strides + // Check if target buffer strides are defined + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + PrimExpr stride = make_const(DataType::Int(32), 1); + for (size_t i = buffer->shape.size(); i > 0; --i) { + const PrimExpr& shape = source_buffer->shape[i - 1]; + Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); + stride *= shape; + } + } + + // Step 2.4. 
Check and update shape + ICHECK(source->region.size() >= buffer->shape.size()); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& range = source->region[i + offset]; + Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i)); + } + } + + void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { + CHECK_EQ(arg.dtype(), value.dtype()) + << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; + // Handle recursive case + value = Substitute(std::move(value), var_map_); + if (arg->IsInstance()) { + Var v = Downcast(arg); + auto it = var_map_.find(v); + if (it == var_map_.end()) { + var_map_.Set(v, value); + analyzer_.Bind(v, value); + } else { + AssertBinding((*it).second, value, arg_name); + } + } else { + AssertBinding(arg, value, arg_name); + } + } + + void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, + const std::string& arg_name = "argument") { + CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name + << " unmet: " << lhs << "==" << rhs << "."; + } + + private: + /*! \brief Buffer region mapping. */ + Map match_buffers_; + /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ + Map var_map_; + /*! 
\brief The analyzer */ + arith::Analyzer analyzer_; +}; + +PrimFunc LowerMatchBuffer(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + fptr->body = MatchBufferLower(func)(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass LowerMatchBuffer() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return LowerMatchBuffer(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py new file mode 100644 index 000000000000..3fa4795870d5 --- /dev/null +++ b/tests/python/integration/test_lower.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
"""Test workload for lowering and build"""
import tvm
from tvm import tir
from tvm.script import ty
import tvm.testing
import numpy as np


@tvm.script.tir
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 1024x1024x1024 fp16 GEMM (C = A @ B^T) tiled for TensorCore wmma intrinsics.
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                # NOTE(review): the loop var `ty` shadows the `tvm.script.ty` import
                # inside this function — confirm this is intentional (script-local name).
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Zero-initialize the accumulator fragments.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )

                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # +8 column padding to avoid shared-memory bank conflicts
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                # Load A fragments from (padded) shared memory.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load B fragments (col_major, since C = A @ B^T).
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Multiply-accumulate on the fragments.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Write accumulator fragments back to global C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )


@tvm.testing.requires_cuda
def test_gemm_tensorcore():
    """Build the TensorCore GEMM for CUDA, check numerics vs numpy, report GFLOPS."""
    dev = tvm.device("cuda", 0)
    a_np = np.random.uniform(size=(1024, 1024)).astype("float16")
    b_np = np.random.uniform(size=(1024, 1024)).astype("float16")
    # Reference: C = A @ B^T in float32 (matches the col_major B fragment loads).
    c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32"))
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
    f = tvm.build(tensorcore_gemm, target="cuda", name="dense")
    f(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

    evaluator = f.time_evaluator(f.entry_name, dev, number=100)
    t = evaluator(a, b, c).mean
    num_flops = 2 * 1024 * 1024 * 1024
    gflops = num_flops / (t * 1e3) / 1e6
    print("gemm with tensor core: %f ms" % (t * 1e3))
    print("GFLOPS: %f" % gflops)


if __name__ == "__main__":
    test_gemm_tensorcore()


@tvm.script.tir
def match_buffer_func(a: ty.handle, b: ty.handle) -> None:
    # Workload with nested match_buffers used by the LCA-detection test below.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    with tir.block([8, 8], "block") as [vi, vj]:
        tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
        tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4))
        B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8))
        with tir.block([16, 16], "AAA") as [i, j]:
            AA = tir.match_buffer(A[i, j], ())
            AA[()] = 1.0
        tir.evaluate(B0.data)
        tir.evaluate(B1.data)


def test_buffer_load_store():
    func = buffer_load_store_func
    A, B = [func.buffer_map[x] for x in func.params]


def test_match_buffer():
    """LCA detection must see through match_buffer sub-regions."""
    func = match_buffer_func
    A, B = [func.buffer_map[x] for x in func.params]
    lca = tir.analysis.detect_buffer_access_lca(func)

    root_block = func.body.block
    block = root_block.body.body.body.block
    block_inner = block.body[0].body.body.block

    # LCA of Buffer A is the inner block (A is only accessed via AA inside "AAA")
    assert lca[A] == block_inner

    # LCA of Buffer B is the outer "block" (B0/B1 match buffers and opaque data accesses)
    assert lca[B] == block


if __name__ == "__main__":
    test_buffer_load_store()
    test_opaque_access()
    test_lca_func_root()
    test_match_buffer()


@tvm.script.tir
def match_buffer_func() -> None:
    with tir.block([], "root"):
        A = tir.alloc_buffer((128, 128), "float32")
        B = tir.alloc_buffer((128, 128), "float32")
        tir.reads([])
        tir.writes([])
        # Need to add read/write regions manually to avoid triggering the
        # block access region detector
        with tir.block([8, 8], "block") as [vi, vj]:
            tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
            tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
* 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + tir.reads([]) + tir.writes(AA[i, j]) + AAA = tir.match_buffer(AA[i, j], ()) + AAA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + +@tvm.script.tir +def opaque_block_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((16, 16), "float32") + B = tir.alloc_buffer((16, 16), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -53,5 +95,41 @@ def test_block_access_region_detector(): ) +def test_opaque_block(): + alloc_buffers = opaque_block_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + block0 = opaque_block_func.body.block.body.body.block + ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + tvm.ir.assert_structural_equal(block0.reads, ret[0]) + tvm.ir.assert_structural_equal(block0.writes, ret[1]) + + block1 = block0.body.body.block + ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + tvm.ir.assert_structural_equal(block1.reads, ret[0]) + tvm.ir.assert_structural_equal(block1.writes, ret[1]) + + +def test_match_buffer(): + root_block = match_buffer_func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + alloc_buffers = 
func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + # Check inner block AAA + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + + # Check block + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + # B is opaque access + tvm.ir.assert_structural_equal(block.reads, ret[2]) + + if __name__ == "__main__": test_block_access_region_detector() + test_opaque_block() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py new file mode 100644 index 000000000000..78a8c5117849 --- /dev/null +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -0,0 +1,455 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 

import pytest

import tvm
from tvm import tir
from tvm.script import ty


def _check(original, transformed):
    # Lower match buffers, simplify, and compare against the hand-written expectation.
    mod = tvm.IRModule.from_expr(original)
    mod = tvm.tir.transform.LowerMatchBuffer()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed)


def _check_fail(original):
    # The pass is expected to reject this workload during binding validation.
    mod = tvm.IRModule.from_expr(original)
    with pytest.raises(tvm.TVMError):
        mod = tvm.tir.transform.LowerMatchBuffer()(mod)


@tvm.script.tir
def buffer_load_store(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.match_buffer(c, (16, 16))
    for i, j, k in tir.grid(4, 16, 8):
        with tir.block([]):
            tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2])
            tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2])
            sub_A = tir.match_buffer(
                A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2], (4, 1, 2), offset_factor=1
            )
            sub_C = tir.match_buffer(
                C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1
            )
            for ii, kk in tir.grid(4, 2):
                sub_A[ii, 0, kk] += sub_C[ii, kk]


@tvm.script.tir
def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None:
    # Expected result: loads/stores rewritten directly onto the source buffers.
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.match_buffer(c, (16, 16))
    for i, j, k in tir.grid(4, 16, 8):
        with tir.block([]):
            tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2])
            tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2])
            for ii, kk in tir.grid(4, 2):
                A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk]


@tvm.ir.register_op_attr("tir.intrin_test", "")
def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1):
    # Dummy opaque intrinsic used to observe data/elem_offset/strides/shape bindings.
    return 0


@tvm.script.tir
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (32, 64, 128))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(2, 64, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
            sub_A = tir.match_buffer(
                A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16],
                (16, 1, 16),
                strides=[8192, 128, 1],
                offset_factor=1,
            )
            tir.evaluate(
                tir.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                )
            )
    for i, j, k in tir.grid(64, 2, 8):
        with tir.block([]):
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            tir.reads([])
            tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
            sub_B = tir.match_buffer(
                B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8],
                (32, 8),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            tir.evaluate(
                tir.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    sub_B.strides[0],
                    sub_B.strides[1],
                    sub_B.shape[0],
                    sub_B.shape[1],
                    dtype="handle",
                )
            )


@tvm.script.tir
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    # Expected result: intrinsic args folded to the source buffer's data,
    # computed elem_offset, and compact strides/shape.
    A = tir.match_buffer(a, (32, 64, 128))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(2, 64, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
            tir.evaluate(
                tir.intrin_test(
                    A.data,
                    i * 131072 + j * 128 + k * 16,
                    8192,
                    128,
                    16,
                    1,
                    dtype="handle",
                )
            )
    for i, j, k in tir.grid(64, 2, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
            tir.evaluate(
                tir.intrin_test(
                    B.data,
                    i * 4096 + j * 2048 + k * 8,
                    64,
                    1,
                    32,
                    8,
                    dtype="handle",
                )
            )


@tvm.script.tir
def recursive_match(a: ty.handle, b: ty.handle) -> None:
    # A match buffer whose source is itself a match buffer (two nesting levels).
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            As_0 = tir.var("int32")
            As_1 = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            sub_B = tir.match_buffer(
                B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                offset_factor=1,
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    tir.writes(
                        [
                            sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                            sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        ]
                    )
                    Ass_0 = tir.var("int32")
                    Ass_1 = tir.var("int32")
                    sub_sub_A = tir.match_buffer(
                        sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        strides=[Ass_0, Ass_1],
                        offset_factor=1,
                    )
                    sub_sub_B = tir.match_buffer(
                        sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        offset_factor=1,
                    )
                    tir.evaluate(
                        tir.intrin_test(
                            sub_sub_A.data,
                            sub_sub_A.elem_offset,
                            sub_sub_A.strides[0],
                            sub_sub_A.strides[1],
                            sub_sub_A.shape[0],
                            sub_sub_A.shape[1],
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in tir.grid(4, 4):
                        sub_sub_B[jjj, kkk] = 1


@tvm.script.tir
def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    tir.writes(
                        [
                            A[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                            B[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                        ]
                    )
                    tir.evaluate(
                        tir.intrin_test(
                            A.data,
                            i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
                            64,
                            1,
                            4,
                            4,
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in tir.grid(4, 4):
                        B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1


@tvm.script.tir
def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None:
    # Shapes/strides expressed in the symbolic function params n and m.
    A = tir.match_buffer(a, (n * m, m))
    B = tir.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]])
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            sub_A = tir.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1)
            sub_B = tir.match_buffer(
                B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0, Bs_1], offset_factor=1
            )
            for ii, jj in tir.grid(m, m):
                sub_A[ii, jj] = 1
            for j in range(0, 4):
                tir.evaluate(
                    tir.intrin_test(
                        sub_B.data,
                        sub_B.elem_offset,
                        sub_B.strides[0],
                        sub_B.strides[1],
                        sub_B.shape[0],
                        sub_B.shape[1],
                        dtype="handle",
                    )
                )


@tvm.script.tir
def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None:
    A = tir.match_buffer(a, (n * m, m))
    B = tir.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]])
            for ii, jj in tir.grid(m, m):
                A[i * m + ii, jj] = 1
            for j in range(0, 4):
                tir.evaluate(
                    tir.intrin_test(
                        B.data,
                        i * n * (m * 4),
                        m * 4,
                        1,
                        2,
                        m * 4,
                        dtype="handle",
                    )
                )


@tvm.script.tir
def rank0_buffer(a: ty.handle, b: ty.handle) -> None:
    # Rank-0 (scalar) match buffers: shape (), accessed with [()].
    A = tir.match_buffer(a, (8, 8))
    B = tir.match_buffer(b, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i, j], B[i, j]])
            sub_A = tir.match_buffer(A[i, j], (), offset_factor=1)
            sub_B = tir.match_buffer(B[i, j], (), offset_factor=1)
            sub_A[()] = 1
            tir.evaluate(
                tir.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                )
            )


@tvm.script.tir
def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (8, 8))
    B = tir.match_buffer(b, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i, j], B[i, j]])
            A[i, j] = 1
            tir.evaluate(
                tir.intrin_test(
                    B.data,
                    i * 8 + j,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                )
            )


@tvm.script.tir
def fail_match_load(a: ty.handle) -> None:
    # Must fail: low-level tir.load on a matched buffer's data var is forbidden.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads(A[i, j])
            tir.writes([])
            sub_A = tir.match_buffer(A[i, j], ())
            tir.evaluate(tir.load("float32", sub_A.data, 0))


@tvm.script.tir
def fail_match_store(a: ty.handle) -> None:
    # Must fail: low-level store through a matched buffer's data var is forbidden.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i, j])
            sub_A = tir.match_buffer(A[i, j], ())
            sub_A.data[0] = 1


@tvm.script.tir
def fail_buffer_bind(a: ty.handle) -> None:
    # Must fail: the same stride var cannot bind to two different stride values.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            stride = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1
            )
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1


@tvm.script.tir
def fail_match_func_param(a: ty.handle, m: ty.handle, n: ty.handle) -> None:
    # Must fail: function params are const and cannot be re-bound as strides.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            sub_A = tir.match_buffer(
                A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1
            )
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1


def test_buffer_load_store():
    _check(buffer_load_store, transformed_buffer_load_store)


def test_opaque_access():
    _check(opaque_access, transformed_opaque_access)


def test_recursive_match():
    _check(recursive_match, transformed_recursive_match)


def test_symbolic_match():
    _check(symbolic_match, transformed_symbolic_match)


def test_rank0_buffer():
    _check(rank0_buffer, transformed_rank0_buffer)


def test_fail_load_store():
    _check_fail(fail_match_load)
    _check_fail(fail_match_store)


def test_fail_buffer_bind():
    _check_fail(fail_buffer_bind)


def test_fail_match_func_param():
    _check_fail(fail_match_func_param)


if __name__ == "__main__":
    test_buffer_load_store()
    test_opaque_access()
    test_recursive_match()
    test_symbolic_match()
    test_rank0_buffer()
    test_fail_load_store()
    test_fail_buffer_bind()
    test_fail_match_func_param()

    writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])]
    # Renamed from `match_buffer_region` to follow the tir.match_buffer naming.
    block_match_buffer = tvm.tir.MatchBufferRegion(
        match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)])
    )

        body,
        init=init_body,
        alloc_buffers=[alloc_buffer],
        match_buffers=[block_match_buffer],
        annotations={"attr_key": "attr_value"},
    )

    assert output.find("reads") != -1
    assert output.find("writes") != -1
    assert output.find("alloc_buffer") != -1
    # The printer now emits "match_buffer" instead of "match_buffer_region".
    assert output.find("match_buffer") != -1
    assert output.find("attr") != -1
    assert output.find("with init()") != -1

    test_intimm_cond()
    test_buffer_load_store()
    test_vars()
    test_prim_func()
    test_cast()
    test_attr()

    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        # tir.match_buffer now takes an explicit shape (was match_buffer_region).
        Bb = tir.match_buffer(B[vi : vi + 1, vj], (1, 1))
        C[vi, vj] = Bb[0, 0] + 1.0


@tvm.script.tir
def match_buffer_func(a: ty.handle, c: ty.handle) -> None:
    # Workload where every access goes through match buffers; compaction must
    # rewrite both the allocation and the nested match_buffer sources.
    A = tir.match_buffer(a, (16, 16))
    C = tir.match_buffer(c, (16, 16))
    for i in range(0, 16):
        with tir.block([]):
            A0 = tir.match_buffer(A[i, 0:16], (16))
            C0 = tir.match_buffer(C[i, 0:16], (16))
            B = tir.alloc_buffer((16, 16))
            with tir.block([]):
                B0 = tir.match_buffer(B[i, 0:16], (16))
                for j in range(0, 16):
                    with tir.block([]) as []:
                        A1 = tir.match_buffer(A0[j], ())
                        B1 = tir.match_buffer(B0[j], ())
                        B1[()] = A1[()] + 1.0
            for j in range(0, 16):
                with tir.block([]) as []:
                    C1 = tir.match_buffer(C0[j], ())
                    B2 = tir.match_buffer(B[i, j], ())
                    C1[()] = B2[()] * 2.0


@tvm.script.tir
def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None:
    # Expected result: B compacted from (16, 16) to (1, 16); match buffer
    # sources re-indexed into the compacted region.
    A = tir.match_buffer(a, (16, 16))
    C = tir.match_buffer(c, (16, 16))
    for i in range(0, 16):
        with tir.block([]):
            A0 = tir.match_buffer(A[i, 0:16], (16))
            C0 = tir.match_buffer(C[i, 0:16], (16))
            B = tir.alloc_buffer((1, 16))
            with tir.block([]):
                B0 = tir.match_buffer(B[0, 0:16], (16))
                for j in range(0, 16):
                    with tir.block([]) as []:
                        A1 = tir.match_buffer(A0[j], ())
                        B1 = tir.match_buffer(B0[j], ())
                        B1[()] = A1[()] + 1.0
            for j in range(0, 16):
                with tir.block([]) as []:
                    C1 = tir.match_buffer(C0[j], ())
                    B2 = tir.match_buffer(B[0, j], ())
                    C1[()] = B2[()] * 2.0


def test_elementwise():
    _check(elementwise_func, compacted_elementwise_func)


def test_complex():
    _check(complex_func, compacted_complex_func)
test_match_buffer(): + _check(match_buffer_func, compacted_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -329,3 +379,4 @@ def test_complex(): test_warp_mem() test_symbolic() test_complex() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 3fb8331d39fc..badf5e0e4d10 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -18,6 +18,8 @@ from tvm import tir from tvm.script import ty +# pylint: disable=no-self-argument + @tvm.script.tir class WithInit: @@ -43,11 +45,46 @@ def main(a: ty.handle, b: ty.handle) -> None: B[i] += A[i, j, k] +@tvm.script.tir +class InitWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with tir.init(): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + +@tvm.script.tir +class BranchWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + def test_lower_reduction(): origin_mod = WithInit() mod = tvm.tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch(), True) +def test_lower_match_buffer(): + origin_mod = InitWithMatchBuffer() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + + if __name__ == 
"__main__": test_lower_reduction() + test_lower_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index d42c5e1f8626..022c964df0c7 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -115,6 +115,28 @@ def transformed_func() -> None: ) +@tvm.script.tir +def match_buffer_func() -> None: + C = tir.alloc_buffer((128, 128)) + with tir.block([128]) as [vi]: + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def transformed_match_buffer_func() -> None: + for i in range(0, 128): + with tir.block([128]) as [vi]: + tir.bind(vi, i) + C = tir.alloc_buffer((128, 128)) + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -123,6 +145,11 @@ def test_locate_buffer_allocation(): _check(original_func, transformed_func) +def test_match_buffer_allocation(): + _check(match_buffer_func, transformed_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() + test_match_buffer_allocation() diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index a4d2dec0cce9..4798e9e09865 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -177,19 +177,6 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -def test_complete_opaque_block_error(): - def render(e): - pass - - override_renderer(render) - - try: - from_source(func_with_opaque_block) - except tvm.error.DiagnosticError: - 
return - assert False - - @tvm.script.tir def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: data_buf = tir.match_buffer(data, (16, 16), "float32") @@ -255,10 +242,46 @@ def test_complete_buffer_indices(): tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) +@tvm.script.tir +def match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +@tvm.script.tir +def expected_match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, 0:16]) + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + tir.reads([]) + tir.writes(A0[0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads([]) + tir.writes(A0[j]) + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +def test_complete_match_buffer(): + tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() - test_complete_opaque_block_error() test_complete_part_region() test_complete_buffer_indices() + test_complete_match_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index a72b13e38829..7aeceeccfa89 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -202,7 +202,7 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer_region(vi) # error + A = tir.match_buffer(vi) # error tir.evaluate(1.0) @@ -363,6 +363,23 @@ def test_tvm_exception_catch(): 
check_error(intrin_except_assign, 3) +def buffer_shape_mismatch(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j * 4 : j * 4 + 4]]) + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (5) + ) # error: shape mismatched between 4 and 5 + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +def test_match_buffer_shape_mismatch(): + check_error(buffer_shape_mismatch, 7) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -414,3 +431,4 @@ def render(e): test_error_index_with_stop_slice() test_mismatch_args() test_tvm_exception_catch() + test_match_buffer_shape_mismatch() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c688701d2cca..0566ff5044d9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2820,6 +2820,43 @@ def test_for_thread_binding(): assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" +@tvm.script.tir +def match_buffer_region(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16), "float32") + B = tir.match_buffer(b, (1), "float32") + + with tir.block([16, 4]) as [vi, vj]: + C = tir.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + with tir.block([4]) as [vii]: + D = tir.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in tir.grid(4, 4): + B[0] += D[i, 0, j] + + +def test_match_buffer_region(): + func = match_buffer_region + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body, tir.stmt.BlockRealize) + root = rt_func.body.block + + assert isinstance(root.body, tir.stmt.For) + assert isinstance(root.body.body, tir.stmt.For) + assert isinstance(root.body.body.body, 
tir.stmt.BlockRealize) + outer_block = root.body.body.body.block + assert len(outer_block.match_buffers) == 1 + buffer_C = outer_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + + assert isinstance(outer_block.body, tir.stmt.For) + assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) + inner_block = outer_block.body.body.block + assert len(inner_block.match_buffers) == 1 + buffer_D = inner_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + + @tvm.script.tir def block_elements(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (16, 16), "float32") @@ -2832,10 +2869,10 @@ def block_elements(a: ty.handle, b: ty.handle) -> None: tir.writes(B[0, 0]) tir.block_attr({"attr_key": "attr_value"}) C = tir.alloc_buffer((4, 4), dtype="float32") - D = tir.match_buffer_region(A[0:4, 0]) + D = tir.match_buffer(A[0:4, 0], (4, 1)) with tir.init(): B[0, 0] = tir.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] def test_block_elements(): @@ -2988,6 +3025,7 @@ def test_script_printer(): test_element_wise() test_predicate() test_for_thread_binding() + test_match_buffer_region() test_block_elements() test_opaque_block() test_abs()