diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 41fa6a5fa2f7..1d2550eecde2 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -112,9 +112,9 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) scope = tvm.runtime.convert(scope) return tvm.tir.Allocate( - self.buffer.data, - self.buffer.dtype, - self.buffer.shape, + self.buffer_var, + dtype, + extents, condition, self.body, annotations=annotations, @@ -122,7 +122,7 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) ) super().__init__(allocate, concise_scope=True, def_symbol=True) - self.buffer = None + self.buffer_var = None def enter_scope( self, @@ -146,20 +146,15 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer( + def setup_buffer_var( extents, dtype, scope, condition=True, annotations=None, span: Span = None ): - """Setup buffer object for a given type.""" - self.buffer = tvm.tir.decl_buffer( - shape=extents, - dtype=dtype, - name=name, - scope=scope, - span=span, - ) + """Setup buffer var for a given type.""" + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) + self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) - setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer, node) + setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer_var, node) @register @@ -176,7 +171,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None): list_data.append(i.value) nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) n = tvm.tir.AllocateConst( - self.buffer.data, + self.buffer_var, dtype, shape, nd_data, @@ -187,7 +182,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None): return n super().__init__(allocate_const, concise_scope=True, def_symbol=True) - self.buffer = None 
+ self.buffer_var = None def enter_scope( self, @@ -211,17 +206,13 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None): + def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None): """Setup buffer var for a given type.""" - self.buffer = tvm.tir.decl_buffer( - shape=shape, - dtype=dtype, - name=name, - span=span, - ) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) - setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer, node) + setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer_var, node) @register @@ -248,7 +239,18 @@ def decl_buffer( axis_separators=None, span=None, ): - return tvm.tir.DeclBuffer(self.buffer, self.body, span=span) + decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span) + if data is None: + # when data is not specified, the buffer is implicitly allocated + return tvm.tir.Allocate( + self.buffer.data, + dtype, + shape, + tvm.runtime.convert(True), + decl_buffer, + span=span, + ) + return decl_buffer super().__init__(decl_buffer, concise_scope=True, def_symbol=True) @@ -298,6 +300,7 @@ def setup_buffer( offset_factor=offset_factor, buffer_type=buffer_type, axis_separators=axis_separators, + name=name, span=span, ) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index f5300e1e6985..5da81de4dc5d 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -100,6 +100,12 @@ class BufferUsageFinder : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const DeclBufferNode* op) final { + buffers_declared_.insert(op->buffer.get()); + StmtExprVisitor::VisitStmt_(op); + buffers_declared_.erase(op->buffer.get()); + } + private: explicit BufferUsageFinder(Map> usage) : 
usage_(usage) {} @@ -107,6 +113,9 @@ class BufferUsageFinder : public StmtExprVisitor { if (buffers_visited_.count(buffer.get())) { return; } + if (buffers_declared_.count(buffer.get())) { + return; + } buffers_visited_.insert(buffer.get()); Array arr = usage_.Get(buffer->data).value_or({}); @@ -119,6 +128,9 @@ class BufferUsageFinder : public StmtExprVisitor { // The buffers that have been visited so far, to avoid duplicate // entries in the search result. std::unordered_set buffers_visited_; + // The buffers declared via `DeclBuffer`. These buffers are excluded from the result because + // T.buffer_decl shouldn't be printed for them. + std::unordered_set buffers_declared_; }; /*! @@ -1055,58 +1067,57 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } namespace { -struct AllocUsage { - Buffer alloc_buffer; - Array aliasing_buffers; -}; -template -AllocUsage FindAllocateUsage(AllocNode* op, Map>* cache_ptr) { - Map>& cache = *cache_ptr; - if (!cache.count(op->buffer_var)) { - cache = BufferUsageFinder::FindUsage(std::move(cache), op->body); +bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) { + const Var& buffer_var = allocate->buffer_var; + const DeclBufferNode* decl_buffer = allocate->body.as(); + if (!decl_buffer) { + return false; } - Array buffer_usage = cache.Get(op->buffer_var).value_or({}); - - auto is_exact_match = [](Buffer a, Buffer b) { - if (a->dtype != b->dtype) return false; - if (a->shape.size() != b->shape.size()) return false; - - arith::Analyzer analyzer; - for (size_t i = 0; i < a->shape.size(); i++) { - if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) { - return false; - } - } - return true; - }; - - // If the buffer allocated via T.allocate is an exact match to the - // usage of the buffer later on, then that buffer is the return - // value of T.allocate, and no T.buffer_decl statement is needed. 
- Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, - 0, kDefault); - bool found_alloc_buf = false; - Array aliasing_buffers; - for (const auto& buf : buffer_usage) { - if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) { - alloc_buffer = buf; - found_alloc_buf = true; - } else { - aliasing_buffers.push_back(buf); + const Buffer& buffer = decl_buffer->buffer; + if (!buffer_var.same_as(buffer->data)) { + return false; + } + if (allocate->dtype != buffer->dtype) { + return false; + } + if (!is_one(allocate->condition)) { + return false; + } + if (allocate->annotations.size()) { + return false; + } + if (allocate->extents.size() != buffer->shape.size()) { + return false; + } + tir::ExprDeepEqual expr_equal; + for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) { + if (!expr_equal(allocate->extents[i], buffer->shape[i])) { + return false; } } - - return AllocUsage{alloc_buffer, aliasing_buffers}; + return true; } + } // namespace Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - auto usage = FindAllocateUsage(op, &buffer_var_usage_); - Buffer& alloc_buffer = usage.alloc_buffer; - Array& aliasing_buffers = usage.aliasing_buffers; - buf_not_in_headers_.insert(alloc_buffer.get()); - var_not_in_headers_.insert(alloc_buffer->data.get()); + var_not_in_headers_.insert(op->buffer_var.get()); + + if (!buffer_var_usage_.count(op->buffer_var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); + } + Array buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({}); + + if (buffer_usage.empty()) { + if (IsAllocateDeclBufferPattern(op)) { + // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single + // DeclBuffer statement. 
It is intentional to call `Print` instead of `PrintBody` here to + // delegate the printing of the current node to `DeclBufferNode` while maintaining the + // same value of `current_num_` and `num_child_`. + return Print(op->body); + } + } auto storage_scope = GetPtrStorageScope(op->buffer_var); Doc func_call; @@ -1124,12 +1135,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) - << PrintBody(op->body)); + doc << "with " << func_call << " as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent( + 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); } else { - doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body); + doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; @@ -1179,11 +1190,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { } auto ndarray_str = ss.str(); - auto usage = FindAllocateUsage(alloc, &buffer_var_usage_); - Buffer& alloc_buffer = usage.alloc_buffer; - Array& aliasing_buffers = usage.aliasing_buffers; - buf_not_in_headers_.insert(alloc_buffer.get()); - var_not_in_headers_.insert(alloc_buffer->data.get()); + var_not_in_headers_.insert(alloc->buffer_var.get()); + + if (!buffer_var_usage_.count(alloc->buffer_var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body); + } + Array buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({}); Doc func_call; func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) @@ -1192,12 +1204,12 @@ Doc
TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { Doc doc; var_not_in_headers_.insert(alloc->buffer_var.get()); if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) + doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body)); } else { - doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body); + doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body); } return doc; } diff --git a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py index f348fd7f5a77..8c598fe0d794 100644 --- a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py +++ b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py @@ -40,14 +40,14 @@ def main() -> None: buffer9 = T.buffer_decl([32], "uint8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p2 = T.allocate([112], "uint8", "global") - p3 = T.allocate([112], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") - p5 = T.allocate([32], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") - p7 = T.allocate([112], "uint8", "global") - p8 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([128], "uint8") + p2 = T.decl_buffer([112], "uint8") + p3 = T.decl_buffer([112], "uint8") + p4 = T.decl_buffer([32], "uint8") + p5 = T.decl_buffer([32], "uint8") + p6 = T.decl_buffer([32], "uint8") + p7 = T.decl_buffer([112], "uint8") + p8 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", 
buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -88,14 +88,14 @@ def main() -> None: buffer9 = T.buffer_decl([32], "uint8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p2 = T.allocate([112], "uint8", "global") - p3 = T.allocate([112], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") - p5 = T.allocate([32], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") - p7 = T.allocate([112], "uint8", "global") - p8 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([128], "uint8") + p2 = T.decl_buffer([112], "uint8") + p3 = T.decl_buffer([112], "uint8") + p4 = T.decl_buffer([32], "uint8") + p5 = T.decl_buffer([32], "uint8") + p6 = T.decl_buffer([32], "uint8") + p7 = T.decl_buffer([112], "uint8") + p8 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) @@ -134,14 +134,14 @@ def main() -> None: buffer9 = T.buffer_decl([32], "uint8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p2 = T.allocate([112], "uint8", "global") - p3 = T.allocate([112], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") - p5 = T.allocate([32], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") - p7 = T.allocate([112], "uint8", "global") - p8 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([128], "uint8") + p2 = 
T.decl_buffer([112], "uint8") + p3 = T.decl_buffer([112], "uint8") + p4 = T.decl_buffer([32], "uint8") + p5 = T.decl_buffer([32], "uint8") + p6 = T.decl_buffer([32], "uint8") + p7 = T.decl_buffer([112], "uint8") + p8 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) @@ -166,11 +166,11 @@ def main() -> None: class AllOperatorsWithoutWeights: @T.prim_func def main() -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([36], "int8") buffer2 = T.buffer_decl([9], "int8") # body - p1 = T.allocate([96], "int8", "global") + p1 = T.decl_buffer([96], "int8") T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 4, 3, 3, 0, 4, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 12, 3, 1, "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "int8", 3, 1, 3, 3, 0, 1, buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 3, 1, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -188,19 +188,19 @@ def test_all_operators_without_weights(max_copy_movements): class OperatorsWithAndWithoutWeights: @T.prim_func def main() -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([97156], "int8") buffer2 = T.buffer_decl([80], "uint8") buffer3 = 
T.buffer_decl([64], "uint8") buffer4 = T.buffer_decl([96], "uint8") buffer5 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([390336], "int8", "global") - p2 = T.allocate([80], "uint8", "global") - p3 = T.allocate([64], "uint8", "global") - p4 = T.allocate([390336], "int8", "global") - p5 = T.allocate([96], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([390336], "int8") + p2 = T.decl_buffer([80], "uint8") + p3 = T.decl_buffer([64], "uint8") + p4 = T.decl_buffer([390336], "int8") + p5 = T.decl_buffer([96], "uint8") + p6 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) @@ -230,12 +230,12 @@ def main() -> None: buffer4 = T.buffer_decl([96], "uint8") buffer5 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([390336], "int8", "global") - p2 = T.allocate([80], "uint8", "global") - p3 = T.allocate([64], "uint8", "global") - p4 = T.allocate([390336], "int8", "global") - p5 = T.allocate([96], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([390336], "int8") + p2 = T.decl_buffer([80], "uint8") + p3 = T.decl_buffer([64], "uint8") + p4 = T.decl_buffer([390336], "int8") + p5 = T.decl_buffer([96], "uint8") + p6 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 
114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -256,19 +256,19 @@ def test_operators_with_and_without_weights_max_copy_movements_2(): class ReferenceModule: @T.prim_func def main() -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([97156], "int8") buffer2 = T.buffer_decl([80], "uint8") buffer3 = T.buffer_decl([64], "uint8") buffer4 = T.buffer_decl([96], "uint8") buffer5 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([390336], "int8", "global") - p2 = T.allocate([80], "uint8", "global") - p3 = T.allocate([64], "uint8", "global") - p4 = T.allocate([390336], "int8", "global") - p5 = T.allocate([96], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([390336], "int8") + p2 = T.decl_buffer([80], "uint8") + p3 = T.decl_buffer([64], "uint8") + p4 = T.decl_buffer([390336], "int8") + p5 = T.decl_buffer([96], "uint8") + p6 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle")) @@ -288,7 +288,7 @@ def main() -> None: class CopyToBufferWithLocalScope: @T.prim_func def main() -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([64], "uint8") buffer2 = T.buffer_decl([48], "uint8") buffer3 = T.buffer_decl([48], "uint8") @@ -298,13 +298,13 @@ def main() -> None: buffer7 = T.buffer_decl([256], "uint8") buffer8 = T.buffer_decl([64], "uint8") # body - p1 = 
T.allocate([48], "uint8", "global") - p2 = T.allocate([48], "uint8", "global") - p3 = T.allocate([256], "int8", "local") - p4 = T.allocate([256], "int8", "global") - p5 = T.allocate([16], "uint8", "global") - p6 = T.allocate([48], "uint8", "global") - p7 = T.allocate([256], "int8", "local") + p1 = T.decl_buffer([48], "uint8") + p2 = T.decl_buffer([48], "uint8") + p3 = T.decl_buffer([256], "int8", scope="local") + p4 = T.decl_buffer([256], "int8") + p5 = T.decl_buffer([16], "uint8") + p6 = T.decl_buffer([48], "uint8") + p7 = T.decl_buffer([256], "int8", scope="local") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) @@ -339,13 +339,13 @@ def main() -> None: buffer7 = T.buffer_decl([256], "uint8") buffer8 = T.buffer_decl([64], "uint8") # body - p1 = T.allocate([48], "uint8", "global") - p2 = T.allocate([48], "uint8", "global") - p3 = T.allocate([256], "int8", "local") - p4 = T.allocate([256], "int8", "global") - p5 = T.allocate([16], "uint8", "global") - p6 = T.allocate([48], "uint8", "global") - p7 = T.allocate([256], "int8", "local") + p1 = T.decl_buffer([48], "uint8") + p2 = T.decl_buffer([48], "uint8") + p3 = T.decl_buffer([256], "int8", scope="local") + p4 = T.decl_buffer([256], "int8") + p5 = T.decl_buffer([16], "uint8") + p6 = T.decl_buffer([48], "uint8") + p7 = T.decl_buffer([256], "int8", scope="local") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) @@ -412,12 +412,12 @@ def main() -> None: buffer4 = T.buffer_decl([96], "uint8") buffer5 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([390336], "int8", "global") - p2 = T.allocate([80], "uint8", "global") - p3 = 
T.allocate([64], "uint8", "global") - p4 = T.allocate([390336], "int8", "global") - p5 = T.allocate([96], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([390336], "int8") + p2 = T.decl_buffer([80], "uint8") + p3 = T.decl_buffer([64], "uint8") + p4 = T.decl_buffer([390336], "int8") + p5 = T.decl_buffer([96], "uint8") + p6 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -438,19 +438,19 @@ def test_pass_context_option_max_copy_movements(): class ReferenceModule: @T.prim_func def main() -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([97156], "int8") buffer2 = T.buffer_decl([80], "uint8") buffer3 = T.buffer_decl([64], "uint8") buffer4 = T.buffer_decl([96], "uint8") buffer5 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([390336], "int8", "global") - p2 = T.allocate([80], "uint8", "global") - p3 = T.allocate([64], "uint8", "global") - p4 = T.allocate([390336], "int8", "global") - p5 = T.allocate([96], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") + p1 = T.decl_buffer([390336], "int8") + p2 = T.decl_buffer([80], "uint8") + p3 = T.decl_buffer([64], "uint8") + p4 = T.decl_buffer([390336], "int8") + p5 = T.decl_buffer([96], "uint8") + p6 = T.decl_buffer([32], "uint8") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", 
buffer3[0], 64, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle")) @@ -487,15 +487,15 @@ def main(placeholder: T.Buffer[97156, "int8"], placeholder_encoded: T.Buffer[208 nn_4 = T.var("int32") nn_5 = T.var("int32") # body - placeholder_d_global = T.allocate([208], "uint8", "global") - placeholder_d_global_1 = T.allocate([112], "uint8", "global") - placeholder_d_global_2 = T.allocate([96], "uint8", "global") - placeholder_d_global_3 = T.allocate([112], "uint8", "global") - ethosu_write_1 = T.allocate([195168], "int8", "global") - ethosu_write_2 = T.allocate([184800], "int8", "global") - ethosu_write_3 = T.allocate([174688], "int8", "global") - ethosu_write_4 = T.allocate([174688], "int8", "global") - ethosu_write_5 = T.allocate([174688], "int8", "global") + placeholder_d_global = T.decl_buffer([208], "uint8") + placeholder_d_global_1 = T.decl_buffer([112], "uint8") + placeholder_d_global_2 = T.decl_buffer([96], "uint8") + placeholder_d_global_3 = T.decl_buffer([112], "uint8") + ethosu_write_1 = T.decl_buffer([195168], "int8") + ethosu_write_2 = T.decl_buffer([184800], "int8") + ethosu_write_3 = T.decl_buffer([174688], "int8") + ethosu_write_4 = T.decl_buffer([174688], "int8") + ethosu_write_5 = T.decl_buffer([174688], "int8") with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 1792): T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 208, placeholder_d_global[0], dtype="handle")) with T.attr(T.iter_var(nn, None, "DataPar", ""), "pragma_compute_cycles_hint", 250): @@ -535,15 +535,15 @@ def main(placeholder: T.Buffer[97156, "int8"], placeholder_encoded: T.Buffer[208 nn_4 = T.var("int32") nn_5 = T.var("int32") # body - placeholder_d_global = T.allocate([208], "uint8", "global") - placeholder_d_global_1 = T.allocate([112], "uint8", "global") - placeholder_d_global_2 = T.allocate([96], "uint8", "global") - placeholder_d_global_3 = T.allocate([112], 
"uint8", "global") - ethosu_write_1 = T.allocate([195168], "int8", "global") - ethosu_write_2 = T.allocate([184800], "int8", "global") - ethosu_write_3 = T.allocate([174688], "int8", "global") - ethosu_write_4 = T.allocate([174688], "int8", "global") - ethosu_write_5 = T.allocate([174688], "int8", "global") + placeholder_d_global = T.decl_buffer([208], "uint8") + placeholder_d_global_1 = T.decl_buffer([112], "uint8") + placeholder_d_global_2 = T.decl_buffer([96], "uint8") + placeholder_d_global_3 = T.decl_buffer([112], "uint8") + ethosu_write_1 = T.decl_buffer([195168], "int8") + ethosu_write_2 = T.decl_buffer([184800], "int8") + ethosu_write_3 = T.decl_buffer([174688], "int8") + ethosu_write_4 = T.decl_buffer([174688], "int8") + ethosu_write_5 = T.decl_buffer([174688], "int8") with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 1792): T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 208, placeholder_d_global[0], dtype="handle")) with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None, "DataPar", ""), "pragma_compute_cycles_hint", 1024): @@ -589,17 +589,17 @@ def main(placeholder: T.Buffer[97156, "int8"], placeholder_encoded: T.Buffer[208 nn_4 = T.var("int32") nn_5 = T.var("int32") # body - placeholder_d_d_global = T.allocate([208], "uint8", "global") - placeholder_d_d_global_1 = T.allocate([112], "uint8", "global") - placeholder_d_global = T.allocate([96], "uint8", "global") - ethosu_write_1 = T.allocate([195168], "int8", "global") - placeholder_local = T.allocate([256], "int8", "local") - ethosu_write_2 = T.allocate([184800], "int8", "global") - ethosu_write_3 = T.allocate([184800], "int8", "global") - ethosu_write_4 = T.allocate([184800], "int8", "global") - placeholder_d_local = T.allocate([256], "int8", "local") - ethosu_write_5 = T.allocate([184800], "int8", "global") - placeholder_d_d_local = T.allocate([256], "int8", "local") + placeholder_d_d_global = T.decl_buffer([208], 
"uint8") + placeholder_d_d_global_1 = T.decl_buffer([112], "uint8") + placeholder_d_global = T.decl_buffer([96], "uint8") + ethosu_write_1 = T.decl_buffer([195168], "int8") + placeholder_local = T.decl_buffer([256], "int8", scope="local") + ethosu_write_2 = T.decl_buffer([184800], "int8") + ethosu_write_3 = T.decl_buffer([184800], "int8") + ethosu_write_4 = T.decl_buffer([184800], "int8") + placeholder_d_local = T.decl_buffer([256], "int8", scope="local") + ethosu_write_5 = T.decl_buffer([184800], "int8") + placeholder_d_d_local = T.decl_buffer([256], "int8", scope="local") with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 1792): T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 208, placeholder_d_d_global[0], dtype="handle")) with T.attr(T.iter_var(nn, None, "DataPar", ""), "pragma_compute_cycles_hint", 73668): @@ -639,17 +639,17 @@ def main(placeholder: T.Buffer[97156, "int8"], placeholder_encoded: T.Buffer[208 nn_4 = T.var("int32") nn_5 = T.var("int32") # body - placeholder_d_d_global = T.allocate([208], "uint8", "global") - placeholder_d_d_global_1 = T.allocate([112], "uint8", "global") - placeholder_d_global = T.allocate([96], "uint8", "global") - ethosu_write_1 = T.allocate([195168], "int8", "global") - placeholder_local = T.allocate([256], "int8", "local") - ethosu_write_2 = T.allocate([184800], "int8", "global") - ethosu_write_3 = T.allocate([184800], "int8", "global") - ethosu_write_4 = T.allocate([184800], "int8", "global") - placeholder_d_local = T.allocate([256], "int8", "local") - ethosu_write_5 = T.allocate([184800], "int8", "global") - placeholder_d_d_local = T.allocate([256], "int8", "local") + placeholder_d_d_global = T.decl_buffer([208], "uint8") + placeholder_d_d_global_1 = T.decl_buffer([112], "uint8") + placeholder_d_global = T.decl_buffer([96], "uint8") + ethosu_write_1 = T.decl_buffer([195168], "int8") + placeholder_local = T.decl_buffer([256], "int8", scope="local") + 
ethosu_write_2 = T.decl_buffer([184800], "int8") + ethosu_write_3 = T.decl_buffer([184800], "int8") + ethosu_write_4 = T.decl_buffer([184800], "int8") + placeholder_d_local = T.decl_buffer([256], "int8", scope="local") + ethosu_write_5 = T.decl_buffer([184800], "int8") + placeholder_d_d_local = T.decl_buffer([256], "int8", scope="local") with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 1792): T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 208, placeholder_d_d_global[0], dtype="handle")) with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None, "DataPar", ""), "pragma_compute_cycles_hint", 384): diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index fd9f373739e1..6ffbf22312ff 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -43,8 +43,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer7 = T.buffer_decl([144], "uint8") buffer8 = T.buffer_decl([32], "uint8") # body - p1 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([160], "uint8", data=p1_data) + p2_data = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([144], "uint8", data=p2_data) buffer9 = T.buffer_decl([144], "uint8", data=p1.data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle")) @@ -69,8 +71,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_encoded_4_1 = T.buffer_decl([208], 
dtype="uint8") buffer_encoded_6_1 = T.buffer_decl([192], dtype="uint8") # body - p1 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([208], "uint8", data=p1_data) + p2_data = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([192], "uint8", data=p2_data) p3 = T.buffer_decl([192], dtype="uint8", data=p1.data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 192, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 192, p2[0], dtype="handle")) @@ -149,8 +153,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([384], "uint8") # body - p1 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([384], "uint8", data=p1_data) + p2_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([384], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, 
"TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -167,8 +173,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), # buffer definition placeholder_encoded_1 = T.buffer_decl([464], "uint8") # body - p1 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([464], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -246,7 +254,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_2 = T.buffer_decl([160], "uint8") buffer_3 = T.buffer_decl([80], "uint8") # body - ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 
T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -264,7 +273,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") # body - ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_2 = T.buffer_decl([4096], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ 
-340,9 +350,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,) buffer10 = T.buffer_decl([160], "uint8") buffer11 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([112], "uint8", data=p1_data) + p3_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.buffer_decl([4096], "int8", data=p3_data) + p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([112], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) @@ -369,9 +382,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,) buffer4 = T.buffer_decl([608], dtype="uint8") buffer5 = T.buffer_decl([160], dtype="uint8") buffer6 = T.buffer_decl([2048], dtype="int8") - p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([128], "uint8", "global", 
annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([128], "uint8", data=p1_data) + p2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([4096], "int8", data=p2_data) + p3_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.buffer_decl([128], "uint8", data=p3_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_hoist_allocates.py b/tests/python/contrib/test_ethosu/test_hoist_allocates.py index b54b92950180..6c6d51fa06b9 100644 --- a/tests/python/contrib/test_ethosu/test_hoist_allocates.py +++ b/tests/python/contrib/test_ethosu/test_hoist_allocates.py @@ -116,15 +116,20 @@ def main(placeholder: T.Buffer[(3402,), "int8"], placeholder_encoded: T.Buffer[( T.preflattened_buffer(placeholder_encoded_3, [3, 10], dtype="uint8") T.preflattened_buffer(ethosu_write, [1, 27, 42, 3], dtype="int8", data=ethosu_write.data) # body - placeholder_global = T.allocate([128], "uint8", "global") + placeholder_global_data = T.allocate([128], "uint8", "global") + placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 128, placeholder_global[0], dtype="handle")) - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_d_global_data = T.allocate([32], "uint8", "global") + placeholder_d_global = T.buffer_decl([32], "uint8", 
data=placeholder_d_global_data) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) - ethosu_write_2 = T.allocate([18144], "int8", "global") + ethosu_write_2_data = T.allocate([18144], "int8", "global") + ethosu_write_2 = T.buffer_decl([18144], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 27, 42, 3, 27, 0, 42, placeholder[0], 0, 0, 0, T.float32(0.0039215646684169769), -128, "NHWC", 126, 3, 1, "int8", 27, 42, 3, 27, 0, 42, ethosu_write_2[0], 0, 0, 0, T.float32(0.031308155506849289), -128, "NHCWB16", 672, 16, 1, 2, 3, 1, 1, 1, 2, placeholder_global[0], 128, 0, placeholder_d_global[0], 32, 2, 0, 2, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - placeholder_d_global_1 = T.allocate([128], "uint8", "global") + placeholder_d_global_1_data = T.allocate([128], "uint8", "global") + placeholder_d_global_1 = T.buffer_decl([128], "uint8", data=placeholder_d_global_1_data) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_2[0], 128, placeholder_d_global_1[0], dtype="handle")) - placeholder_d_global_2 = T.allocate([32], "uint8", "global") + placeholder_d_global_2_data = T.allocate([32], "uint8", "global") + placeholder_d_global_2 = T.buffer_decl([32], "uint8", data=placeholder_d_global_2_data) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_3[0], 32, placeholder_d_global_2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 27, 42, 3, 27, 0, 42, ethosu_write_2[0], 0, 0, 0, T.float32(0.031308155506849289), -128, "NHCWB16", 672, 16, 1, "int8", 27, 42, 3, 27, 0, 42, ethosu_write[0], 0, 0, 0, T.float32(0.23604340851306915), -128, "NHWC", 126, 3, 1, 2, 3, 1, 1, 1, 2, placeholder_d_global_1[0], 128, 0, placeholder_d_global_2[0], 32, 2, 0, 2, 1, "CLIP", -128, 127, "TFL", "NONE", dtype="handle")) # fmt: on @@ -151,14 +156,18 @@ def main(placeholder: T.Buffer[(24,), "int8"], T_concat: T.Buffer[(24,), "int8"] T.preflattened_buffer(placeholder, [1, 
2, 3, 4], dtype="int8", data=placeholder.data) T.preflattened_buffer(T_concat, [24], dtype="int8", data=T_concat.data) # body - ethosu_write = T.allocate([12], "int8", "global") + ethosu_write_data = T.allocate([12], "int8", "global") + ethosu_write = T.buffer_decl([12], "int8", data=ethosu_write_data) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 3, 4, 1, 0, 3, placeholder[12], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 3, 4, 1, 0, 3, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - ethosu_write_1 = T.allocate([12], "int8", "global") + ethosu_write_1_data = T.allocate([12], "int8", "global") + ethosu_write_1 = T.buffer_decl([12], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 3, 4, 1, 0, 3, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 3, 4, 1, 0, 3, ethosu_write_1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_identity", "int8", 12, 1, 1, 12, 0, 1, ethosu_write_1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "int8", 12, 1, 1, 12, 0, 1, T_concat[12], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - ethosu_write_2 = T.allocate([12], "int8", "global") + ethosu_write_2_data = T.allocate([12], "int8", "global") + ethosu_write_2 = T.buffer_decl([12], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 3, 4, 1, 0, 3, placeholder[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 3, 4, 1, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - ethosu_write_3 = T.allocate([12], "int8", "global") + ethosu_write_3_data = T.allocate([12], "int8", "global") + 
ethosu_write_3 = T.buffer_decl([12], "int8", data=ethosu_write_3_data) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 3, 4, 1, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 3, 4, 1, 0, 3, ethosu_write_3[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_identity", "int8", 12, 1, 1, 12, 0, 1, ethosu_write_3[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "int8", 12, 1, 1, 12, 0, 1, T_concat[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) # fmt: on @@ -185,24 +194,32 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) # body - with T.allocate([128], "uint8", "global") as placeholder_global: + with T.allocate([128], "uint8", "global") as placeholder_global_data: + placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, placeholder_global[0], dtype="handle")) - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_d_global_data = T.allocate([32], "uint8", "global") + placeholder_d_global = T.buffer_decl([32], "uint8", data=placeholder_d_global_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
- with T.allocate([112], "uint8", "global") as placeholder_global_1: + with T.allocate([112], "uint8", "global") as placeholder_global_1_data: + placeholder_global_1 = T.buffer_decl([112], "uint8", data=placeholder_global_1_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 112, placeholder_global_1[0], dtype="handle")) - placeholder_d_global_1 = T.allocate([32], "uint8", "global") + placeholder_d_global_1_data = T.allocate([32], "uint8", "global") + placeholder_d_global_1 = T.buffer_decl([32], "uint8", data=placeholder_d_global_1_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12, placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - with T.allocate([112], "uint8", "global") as placeholder_global_2: + with T.allocate([112], "uint8", "global") as placeholder_global_2_data: + placeholder_global_2 = T.buffer_decl([112], "uint8", data=placeholder_global_2_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112, placeholder_global_2[0], dtype="handle")) - placeholder_d_global_2 = T.allocate([32], "uint8", "global") + placeholder_d_global_2_data = T.allocate([32], "uint8", "global") + placeholder_d_global_2 = T.buffer_decl([32], "uint8", data=placeholder_d_global_2_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 
112, 12, placeholder_d_global_2[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - placeholder_global_3 = T.allocate([112], "uint8", "global") + placeholder_global_3_data = T.allocate([112], "uint8", "global") + placeholder_global_3 = T.buffer_decl([112], "uint8", data=placeholder_global_3_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 112, placeholder_global_3[0], dtype="handle")) - placeholder_d_global_3 = T.allocate([32], "uint8", "global") + placeholder_d_global_3_data = T.allocate([32], "uint8", "global") + placeholder_d_global_3 = T.buffer_decl([32], "uint8", data=placeholder_d_global_3_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 112, 12, placeholder_d_global_3[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -227,13 +244,20 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) # body - placeholder_global = T.allocate([128], "uint8", "global") - placeholder_global_1 = T.allocate([112], "uint8", "global") - placeholder_global_2 = T.allocate([112], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") - placeholder_d_global_1 = T.allocate([32], "uint8", "global") - placeholder_d_global_2 = T.allocate([32], "uint8", "global") - placeholder_global_3 = T.allocate([112], "uint8", "global") + placeholder_global_data = T.allocate([128], "uint8", "global") + placeholder_global = T.buffer_decl([128], "uint8", 
data=placeholder_global_data) + placeholder_global_1_data = T.allocate([112], "uint8", "global") + placeholder_global_1 = T.buffer_decl([112], "uint8", data=placeholder_global_1_data) + placeholder_global_2_data = T.allocate([112], "uint8", "global") + placeholder_global_2 = T.buffer_decl([112], "uint8", data=placeholder_global_2_data) + placeholder_d_global_data = T.allocate([32], "uint8", "global") + placeholder_d_global = T.buffer_decl([32], "uint8", data=placeholder_d_global_data) + placeholder_d_global_1_data = T.allocate([32], "uint8", "global") + placeholder_d_global_1 = T.buffer_decl([32], "uint8", data=placeholder_d_global_1_data) + placeholder_d_global_2_data = T.allocate([32], "uint8", "global") + placeholder_d_global_2 = T.buffer_decl([32], "uint8", data=placeholder_d_global_2_data) + placeholder_global_3_data = T.allocate([112], "uint8", "global") + placeholder_global_3 = T.buffer_decl([112], "uint8", data=placeholder_global_3_data) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -242,7 +266,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12, placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, 
"TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112, placeholder_global_2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) - placeholder_d_global_3 = T.allocate([32], "uint8", "global") + placeholder_d_global_3_data = T.allocate([32], "uint8", "global") + placeholder_d_global_3 = T.buffer_decl([32], "uint8", data=placeholder_d_global_3_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 112, 12, placeholder_d_global_2[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 112, placeholder_global_3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index caf09abdb020..337b5c70d125 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -44,8 +44,10 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") + p1_data = T.allocate([128], "uint8", "global") + p1 = T.buffer_decl([128], "uint8", data=p1_data) + p4_data = T.allocate([32], "uint8", "global") + p4 = T.buffer_decl([32], "uint8", data=p4_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], 
dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -60,7 +62,8 @@ def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p4 = T.allocate([160], "uint8", "global") + p4_data = T.allocate([160], "uint8", "global") + p4 = T.buffer_decl([160], "uint8", data=p4_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -86,14 +89,22 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p2 = T.allocate([112], "uint8", "global") - p3 = T.allocate([112], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") - p5 = T.allocate([32], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") - p7 = T.allocate([112], "uint8", "global") - p8 = T.allocate([3], "uint8", "global") + p1_data = T.allocate([128], "uint8", "global") + p1 = T.buffer_decl([128], "uint8", data=p1_data) + p2_data = T.allocate([112], "uint8", "global") + p2 = T.buffer_decl([112], "uint8", data=p2_data) + p3_data = T.allocate([112], "uint8", "global") + p3 = T.buffer_decl([112], "uint8", data=p3_data) + p4_data = T.allocate([32], "uint8", "global") + p4 = T.buffer_decl([32], 
"uint8", data=p4_data) + p5_data = T.allocate([32], "uint8", "global") + p5 = T.buffer_decl([32], "uint8", data=p5_data) + p6_data = T.allocate([32], "uint8", "global") + p6 = T.buffer_decl([32], "uint8", data=p6_data) + p7_data = T.allocate([112], "uint8", "global") + p7 = T.buffer_decl([112], "uint8", data=p7_data) + p8_data = T.allocate([3], "uint8", "global") + p8 = T.buffer_decl([3], "uint8", data=p8_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) @@ -117,10 +128,14 @@ def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p4 = T.allocate([160], "uint8", "global") - p7 = T.allocate([144], "uint8", "global") - p10 = T.allocate([144], "uint8", "global") - p11 = T.allocate([144], "uint8", "global") + p4_data = T.allocate([160], "uint8", "global") + p4 = T.buffer_decl([160], "uint8", data=p4_data) + p7_data = T.allocate([144], "uint8", "global") + p7 = T.buffer_decl([144], "uint8", data=p7_data) + p10_data = T.allocate([144], "uint8", "global") + p10 = T.buffer_decl([144], "uint8", data=p10_data) + p11_data = T.allocate([144], "uint8", "global") + p11 = T.buffer_decl([144], "uint8", data=p11_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -159,13 +174,15 @@ def 
test_operators_with_and_without_weights(): class InputModule: @T.prim_func def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer0 = T.buffer_decl([390336], "int8") buffer1 = T.buffer_decl([97156], "int8") buffer6 = T.buffer_decl([390336], "int8") # body - p2 = T.allocate([80], "uint8", "global") - p3 = T.allocate([64], "uint8", "global") + p2_data = T.allocate([80], "uint8", "global") + p2 = T.buffer_decl([80], "uint8", data=p2_data) + p3_data = T.allocate([64], "uint8", "global") + p3 = T.buffer_decl([64], "uint8", data=p3_data) T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) @@ -176,12 +193,13 @@ def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) - class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer[(144,), "uint8"]) -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer0 = T.buffer_decl([390336], "int8") buffer1 = T.buffer_decl([97156], "int8") buffer6 = T.buffer_decl([390336], "int8") # body - p3 = T.allocate([144], "uint8", "global") + p3_data = T.allocate([144], "uint8", "global") + p3 = T.buffer_decl([144], "uint8", data=p3_data) T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 
0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 144, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p3[0], 80, 0, p3[80], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -203,8 +221,8 @@ def test_copy_to_buffer_with_local_scope(): @tvm.script.ir_module class InputModule: @T.prim_func - def main(buffer1: T.Buffer[(64,), "uint8"], - buffer2: T.Buffer[(48,), "uint8"], + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(48,), "uint8"], buffer3: T.Buffer[(256,), "uint8"], buffer4: T.Buffer[(256,), "uint8"], buffer5: T.Buffer[(16,), "uint8"], @@ -215,12 +233,18 @@ def main(buffer1: T.Buffer[(64,), "uint8"], ) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # body - p1 = T.allocate([48], "uint8", "global") - p2 = T.allocate([48], "uint8", "global") - p3 = T.allocate([256], "int8", "local") - p5 = T.allocate([16], "uint8", "global") - p6 = T.allocate([48], "uint8", "global") - p7 = T.allocate([256], "int8", "local") + p1_data = T.allocate([48], "uint8", "global") + p1 = T.buffer_decl([48], "uint8", data=p1_data) + p2_data = T.allocate([48], "uint8", "global") + p2 = T.buffer_decl([48], "uint8", data=p2_data) + p3_data = T.allocate([256], "int8", "local") + p3 = T.buffer_decl([256], "int8", data=p3_data, scope="local") + p5_data = T.allocate([16], "uint8", "global") + p5 = T.buffer_decl([16], "uint8", data=p5_data) + p6_data = T.allocate([48], "uint8", "global") + p6 = T.buffer_decl([48], "uint8", 
data=p6_data) + p7_data = T.allocate([256], "int8", "local") + p7 = T.buffer_decl([256], "int8", data=p7_data, scope="local") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local @@ -234,8 +258,8 @@ def main(buffer1: T.Buffer[(64,), "uint8"], @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(buffer1: T.Buffer[(64,), "uint8"], - buffer2: T.Buffer[(96,), "uint8"], + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(256,), "uint8"], buffer5: T.Buffer[(64,), "uint8"], buffer7: T.Buffer[(256,), "uint8"], @@ -244,10 +268,14 @@ def main(buffer1: T.Buffer[(64,), "uint8"], ) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # body - p1 = T.allocate([96], "uint8", "global") - p2 = T.allocate([64], "uint8", "global") - p3 = T.allocate([256], "int8", "local") - p7 = T.allocate([256], "int8", "local") + p1_data = T.allocate([96], "uint8", "global") + p1 = T.buffer_decl([96], "uint8", data=p1_data) + p2_data = T.allocate([64], "uint8", "global") + p2 = T.buffer_decl([64], "uint8", data=p2_data) + p3_data = T.allocate([256], "int8", "local") + p3 = T.buffer_decl([256], "int8", data=p3_data, scope="local") + p7_data = T.allocate([256], "int8", "local") + p7 = T.buffer_decl([256], "int8", data=p7_data, scope="local") T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 64, p2[0], dtype="handle")) @@ -287,10 +315,11 @@ def main() -> None: placeholder = T.buffer_decl([20], "int8") ethosu_write = T.buffer_decl([16], "int8") # body - ethosu_write_4 = T.allocate([16], "int8", "global") + ethosu_write_4_data = 
T.allocate([16], "int8", "global") + ethosu_write_4 = T.buffer_decl([16], "int8", data=ethosu_write_4_data) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - + @tvm.script.ir_module class ReferenceModule: @T.prim_func @@ -300,7 +329,8 @@ def main() -> None: placeholder = T.buffer_decl([20], "int8") ethosu_write = T.buffer_decl([16], "int8") # body - ethosu_write_4 = T.allocate([16], "int8", "global") + ethosu_write_4_data = T.allocate([16], "int8", "global") + ethosu_write_4 = T.buffer_decl([16], "int8", data=ethosu_write_4_data) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -324,8 +354,10 @@ def main(buffer2: T.Buffer[(128,), 
"uint8"], buffer3: T.Buffer[(32,), "uint8"]) buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p4 = T.allocate([32], "uint8", "global") + p1_data = T.allocate([128], "uint8", "global") + p1 = T.buffer_decl([128], "uint8", data=p1_data) + p4_data = T.allocate([32], "uint8", "global") + p4 = T.buffer_decl([32], "uint8", data=p4_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -343,7 +375,8 @@ def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p5 = T.allocate([160], "uint8", "global") + p5_data = T.allocate([160], "uint8", "global") + p5 = T.buffer_decl([160], "uint8", data=p5_data) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) @@ -373,8 +406,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", 
data=ethosu_write.data) # body - p1 = T.allocate([368], "uint8", "global") - p2 = T.allocate([96], "uint8", "global") + p1_data = T.allocate([368], "uint8", "global") + p1 = T.buffer_decl([368], "uint8", data=p1_data) + p2_data = T.allocate([96], "uint8", "global") + p2 = T.buffer_decl([96], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -388,7 +423,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # body - p1 = T.allocate([464], "uint8", "global") + p1_data = T.allocate([464], "uint8", "global") + p1 = T.buffer_decl([464], "uint8", data=p1_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -428,14 +464,22 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([128], "uint8", "global") - p2 = T.allocate([112], "uint8", "global") - p3 = T.allocate([112], "uint8", 
"global") - p4 = T.allocate([32], "uint8", "global") - p5 = T.allocate([32], "uint8", "global") - p6 = T.allocate([32], "uint8", "global") - p7 = T.allocate([112], "uint8", "global") - p8 = T.allocate([3], "uint8", "global") + p1_data = T.allocate([128], "uint8", "global") + p1 = T.buffer_decl([128], "uint8", data=p1_data) + p2_data = T.allocate([112], "uint8", "global") + p2 = T.buffer_decl([112], "uint8", data=p2_data) + p3_data = T.allocate([112], "uint8", "global") + p3 = T.buffer_decl([112], "uint8", data=p3_data) + p4_data = T.allocate([32], "uint8", "global") + p4 = T.buffer_decl([32], "uint8", data=p4_data) + p5_data = T.allocate([32], "uint8", "global") + p5 = T.buffer_decl([32], "uint8", data=p5_data) + p6_data = T.allocate([32], "uint8", "global") + p6 = T.buffer_decl([32], "uint8", data=p6_data) + p7_data = T.allocate([112], "uint8", "global") + p7 = T.buffer_decl([112], "uint8", data=p7_data) + p8_data = T.allocate([3], "uint8", "global") + p8 = T.buffer_decl([3], "uint8", data=p8_data) with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 100): T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) with T.attr(T.iter_var(v1b, None, "DataPar", ""), "pragma_compute_cycles_hint", 101): @@ -479,10 +523,14 @@ def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer1 = T.buffer_decl([8192], "int8") buffer10 = T.buffer_decl([2048], "int8") # body - p4 = T.allocate([160], "uint8", "global") - p7 = T.allocate([144], "uint8", "global") - p10 = T.allocate([144], "uint8", "global") - p11 = T.allocate([144], "uint8", "global") + p4_data = T.allocate([160], "uint8", "global") + p4 = T.buffer_decl([160], "uint8", data=p4_data) + p7_data = T.allocate([144], "uint8", "global") + p7 = T.buffer_decl([144], "uint8", data=p7_data) + p10_data = T.allocate([144], "uint8", "global") + p10 = T.buffer_decl([144], "uint8", data=p10_data) + p11_data = T.allocate([144], "uint8", "global") + p11 = 
T.buffer_decl([144], "uint8", data=p11_data) with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201): T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index d2c759a0ae4d..e6414c24d4a3 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -42,7 +42,8 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) buffer_6 = T.buffer_decl([2992], "uint8") buffer_7 = T.buffer_decl([160], "uint8") # body - T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) + T_concat_1_data = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) + T_concat_1 = T.buffer_decl([2816], "int8", data=T_concat_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 
192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 46a3c5a15bf5..ae46057369e0 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -374,7 +374,8 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([160], "uint8") # body - ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2_data = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.buffer_decl([1024], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 
0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -393,7 +394,8 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") # body - ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.buffer_decl([1536], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -412,7 +414,8 @@ def main(placeholder_5: T.Buffer[(768,), 
"int8"], ethosu_write_1: T.Buffer[(640, buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([880], "uint8") # body - ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2_data = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.buffer_decl([2560], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -433,7 +436,8 @@ def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(204 buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([11040], "uint8") # body - ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2_data = T.allocate([2304], "int8", "global", 
annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.buffer_decl((2304,), "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -452,7 +456,8 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") # body - ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, 
ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -471,7 +476,8 @@ def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,) buffer_2 = T.buffer_decl([11040], "uint8") buffer_3 = T.buffer_decl([272], "uint8") # body - ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1_data = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.buffer_decl([12288], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 
35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_3[0], 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) __tvm_meta__ = None diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 6b97b38d80e6..8c7ff35272ef 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -36,7 +36,8 @@ def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(204 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer_1 = T.buffer_decl([384], "uint8") # body - placeholder_global = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global = T.buffer_decl([384], "uint8", data=placeholder_global_data) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -78,8 +79,10 @@ def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(409 buffer = T.buffer_decl([528], "uint8") buffer_2 = T.buffer_decl([336], "uint8") # body - placeholder_d_global = T.allocate([528], "uint8", 
"global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_1 = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_data = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.buffer_decl([528], "uint8", data=placeholder_d_global_data) + placeholder_d_global_1_data = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_1 = T.buffer_decl([336], "uint8", data=placeholder_d_global_1_data) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index ba050de2b473..254abab644a2 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -184,10 +184,14 @@ def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([2848], "uint8") buffer3 = T.buffer_decl([976], "uint8") - p1 = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) - p5 = T.allocate([75264], "int8", "global", 
annotations={"disable_lower_builtin":True}) - p6 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) + p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.buffer_decl([2848], "uint8", data=p1_data) + p2_data = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.buffer_decl([976], "uint8", data=p2_data) + p5_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) + p5 = T.buffer_decl([75264], "int8", data=p5_data) + p6_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) + p6 = T.buffer_decl([75264], "int8", data=p6_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index e1a0e143281b..f8a84aa08367 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -56,8 +56,8 @@ def main(placeholder_6: T.Buffer[(192,), "int8"], ethosu_conv2d_1: T.Buffer[(512 placeholder_8 = T.buffer_decl([1], "uint8") placeholder_5 = T.buffer_decl([1], "uint8") # body - ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") - ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") + ethosu_conv2d_2 = T.decl_buffer([1024], "uint8") + ethosu_conv2d_3 = T.decl_buffer([2048], "uint8") 
T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_8[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_5[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_8[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) @@ -76,8 +76,8 @@ def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(20 placeholder_5 = T.buffer_decl([1], "int32") placeholder_4 = T.buffer_decl([1], "uint8") # body - placeholder_global = T.allocate([256], "uint8", "global") - placeholder_d_global = T.allocate([8], "int32", "global") + placeholder_global = T.decl_buffer([256], "uint8") + placeholder_d_global = T.decl_buffer([8], "int32") T.evaluate(T.call_extern("ethosu_copy", placeholder_4[0], 256, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", placeholder_5[0], 8, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 
32, 1, "uint8", 16, 16, 8, 16, 0, 16, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -110,8 +110,10 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_6.name: buffer_6, buffer_7.name: buffer_7}}) # body - placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.decl_buffer([128], "uint8", data=placeholder_global_data) + placeholder_d_global_data = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.decl_buffer([32], "uint8", data=placeholder_d_global_data) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -158,9 +160,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_8.name: buffer_8, buffer_9.name: buffer_9}}) # body - ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global = T.allocate([80], "uint8", 
"global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data) + placeholder_global_data = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.buffer_decl([80], "uint8", data=placeholder_global_data) + placeholder_d_global_data = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.buffer_decl([32], "uint8", data=placeholder_d_global_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) @@ -678,10 +683,10 @@ def main(placeholder_4: T.Buffer[(2048,), "int8"], ethosu_write_1: T.Buffer[(16, buffer_1.name: buffer_1, buffer_2.name: buffer_2}}) # body - placeholder_global = T.allocate([272], "uint8", "global") - placeholder_d_global = T.allocate([160], "uint8", "global") - ethosu_write_2 = T.allocate([16], "int16", "global") - placeholder_d_global_1 = T.allocate([1], "int16", "global") + placeholder_global = T.decl_buffer([272], "uint8") + placeholder_d_global = T.decl_buffer([160], "uint8") + ethosu_write_2 = T.decl_buffer([16], "int16") + placeholder_d_global_1 = T.decl_buffer([1], "int16") T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 272, 
placeholder_global[0], dtype="uint8")) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 160, placeholder_d_global[0], dtype="uint8")) T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, placeholder_4[0], 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, placeholder_global[0], 272, 0, placeholder_d_global[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="int16")) diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index 0b1e0f402b9d..e7632561c05c 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -63,9 +63,9 @@ def main(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) - B_local = T.allocate([64], "float32", "local") - Apad_shared = T.allocate([512], "float32", "shared") - Apad_shared_local = T.allocate([8], "float32", "local") + B_local = T.decl_buffer([64], "float32", scope="local") + Apad_shared = T.decl_buffer([512], "float32", scope="shared") + Apad_shared_local = T.decl_buffer([8], "float32", scope="local") T.launch_thread(blockIdx_y, 8) T.launch_thread(blockIdx_x, 4) T.launch_thread(threadIdx_y, 8) @@ -105,9 +105,9 @@ def main(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) - B_local = T.allocate([6400000], "float32", "local") - Apad_shared = T.allocate([512], "float32", "shared") - Apad_shared_local = T.allocate([8], "float32", "local") + B_local = T.decl_buffer([6400000], "float32", scope="local") + Apad_shared = T.decl_buffer([512], "float32", scope="shared") + Apad_shared_local = 
T.decl_buffer([8], "float32", scope="local") T.launch_thread(blockIdx_y, 8) T.launch_thread(blockIdx_x, 4) T.launch_thread(threadIdx_y, 8) @@ -151,9 +151,9 @@ def main(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) - B_local = T.allocate([64], "float32", "local") - Apad_shared = T.allocate([512000], "float32", "shared") - Apad_shared_local = T.allocate([8], "float32", "local") + B_local = T.decl_buffer([64], "float32", scope="local") + Apad_shared = T.decl_buffer([512000], "float32", scope="shared") + Apad_shared_local = T.decl_buffer([8], "float32", scope="local") T.launch_thread(blockIdx_y, 8) T.launch_thread(blockIdx_x, 4) T.launch_thread(threadIdx_y, 8) @@ -197,9 +197,9 @@ def main(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) - B_local = T.allocate([64], "float32", "local") - Apad_shared = T.allocate([512], "float32", "shared") - Apad_shared_local = T.allocate([8], "float32", "local") + B_local = T.decl_buffer([64], "float32", scope="local") + Apad_shared = T.decl_buffer([512], "float32", scope="shared") + Apad_shared_local = T.decl_buffer([8], "float32", scope="local") T.launch_thread(blockIdx_y, 8) T.launch_thread(blockIdx_x, 4) T.launch_thread(threadIdx_y, 8) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 1d78458b930d..12c892a04b07 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -31,8 +31,8 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_49 = T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body 
- PaddedInput_22 = T.allocate([131072], "int16", "global") - DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") + PaddedInput_22 = T.decl_buffer([131072], "int16") + DepthwiseConv2d_9 = T.decl_buffer([100352], "int32") for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), placeholder_147[((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)], T.int16(0), dtype="int16") for i_9, j_9, c_9 in T.grid(14, 14, 512): @@ -63,25 +63,25 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=64, offset_factor=1) sid_21 = T.allocate_const([0,1,2,3,4,5,6,7], "int8", [8]) # body - PaddedInput_25 = T.allocate([131072], "int16", "global") + PaddedInput_25 = T.decl_buffer([131072], "int16") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16") - T_add_11 = T.allocate([100352], "int32", "global") - with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11: + T_add_11 = T.decl_buffer([100352], "int32") + with T.decl_buffer([100352], "int32") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = 
(DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47]) - compute_22 = T.allocate([100352], "int32", "global") - with T.allocate([100352], "int32", "global") as T_cast_78: + compute_22 = T.decl_buffer([100352], "int32") + with T.decl_buffer([100352], "int32") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32") - T_cast_79 = T.allocate([100352], "uint8", "global") - with T.allocate([100352], "int32", "global") as compute_23: + T_cast_79 = T.decl_buffer([100352], "uint8") + with T.decl_buffer([100352], "int32") as compute_23: for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0) for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 49121614ffa0..344f37a23677 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -52,7 +52,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: with T.block(): T.reads([]) T.writes(B[0:16, 0:16]) - A = T.allocate([256], "float32", "global") + A = T.decl_buffer([256], "float32") for i, j in T.grid(16, 16): A[i * 16 + j] = 1 for i in range(0, 16): diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 23405fdee98a..bee9b7b48020 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -36,9 +36,9 @@ def 
gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([1], "float64", scope="local") - MultiB = T.allocate([1], "float64", scope="local") - Accum = T.allocate([2], "float64", scope="local") + MultiA = T.decl_buffer([1], "float64", scope="local") + MultiB = T.decl_buffer([1], "float64", scope="local") + Accum = T.decl_buffer([2], "float64", scope="local") for i in range(2): Accum[i] = T.float64(0) @@ -106,9 +106,9 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([4], "float16", scope="local") - MultiB = T.allocate([4], "float16", scope="local") - Accum = T.allocate([8], "float16", scope="local") + MultiA = T.decl_buffer([4], "float16", scope="local") + MultiB = T.decl_buffer([4], "float16", scope="local") + Accum = T.decl_buffer([8], "float16", scope="local") for i in range(8): Accum[i] = T.float32(0) @@ -187,9 +187,10 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([4], "float16", scope="local") - MultiB = T.allocate([4], "float16", scope="local") - Accum = T.allocate([8], "float32", scope="local") + MultiA = T.decl_buffer([4], "float16", scope="local") + MultiB = T.decl_buffer([4], "float16", scope="local") + Accum = T.decl_buffer([8], "float32", scope="local") + for i in range(8): Accum[i] = T.float32(0) @@ -274,9 +275,9 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([4], "int8", scope="local") - MultiB = T.allocate([4], "int8", scope="local") - Accum = T.allocate([2], "int32", scope="local") + MultiA = T.decl_buffer([4], "int8", scope="local") + MultiB = 
T.decl_buffer([4], "int8", scope="local") + Accum = T.decl_buffer([2], "int32", scope="local") for i in range(2): Accum[i] = T.int32(0) @@ -350,9 +351,9 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([4], "int8", scope="local") - MultiB = T.allocate([4], "uint8", scope="local") - Accum = T.allocate([2], "int32", scope="local") + MultiA = T.decl_buffer([4], "int8", scope="local") + MultiB = T.decl_buffer([4], "uint8", scope="local") + Accum = T.decl_buffer([2], "int32", scope="local") for i in range(2): Accum[i] = T.int32(0) @@ -426,9 +427,9 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "int4", scope="local") - MultiB = T.allocate([8], "int4", scope="local") - Accum = T.allocate([2], "int32", scope="local") + MultiA = T.decl_buffer([8], "int4", scope="local") + MultiB = T.decl_buffer([8], "int4", scope="local") + Accum = T.decl_buffer([2], "int32", scope="local") for i in range(2): Accum[i] = T.int32(0) @@ -494,9 +495,9 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "int4", scope="local") - MultiB = T.allocate([8], "uint4", scope="local") - Accum = T.allocate([2], "int32", scope="local") + MultiA = T.decl_buffer([8], "int4", scope="local") + MultiB = T.decl_buffer([8], "uint4", scope="local") + Accum = T.decl_buffer([2], "int32", scope="local") for i in range(2): Accum[i] = T.int32(0) @@ -562,9 +563,9 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle) T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([4], "float16", scope="local") - MultiB = T.allocate([2], "float16", scope="local") - Accum = 
T.allocate([4], "float32", scope="local") + MultiA = T.decl_buffer([4], "float16", scope="local") + MultiB = T.decl_buffer([2], "float16", scope="local") + Accum = T.decl_buffer([4], "float32", scope="local") for i in range(4): Accum[i] = T.float32(0) @@ -640,9 +641,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "float16", scope="local") - MultiB = T.allocate([4], "float16", scope="local") - Accum = T.allocate([4], "float16", scope="local") + MultiA = T.decl_buffer([8], "float16", scope="local") + MultiB = T.decl_buffer([4], "float16", scope="local") + Accum = T.decl_buffer([4], "float16", scope="local") for i in range(4): Accum[i] = T.float32(0) @@ -722,9 +723,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "float16", scope="local") - MultiB = T.allocate([4], "float16", scope="local") - Accum = T.allocate([4], "float32", scope="local") + MultiA = T.decl_buffer([8], "float16", scope="local") + MultiB = T.decl_buffer([4], "float16", scope="local") + Accum = T.decl_buffer([4], "float32", scope="local") for i in range(4): Accum[i] = T.float32(0) @@ -804,9 +805,9 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "int8", scope="local") - MultiB = T.allocate([4], "int8", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([8], "int8", scope="local") + MultiB = T.decl_buffer([4], "int8", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -886,9 +887,9 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) 
T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([8], "int8", scope="local") - MultiB = T.allocate([4], "uint8", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([8], "int8", scope="local") + MultiB = T.decl_buffer([4], "uint8", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -968,9 +969,9 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([16], "int8", scope="local") - MultiB = T.allocate([8], "int8", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([16], "int8", scope="local") + MultiB = T.decl_buffer([8], "int8", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -1050,9 +1051,9 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([16], "int8", scope="local") - MultiB = T.allocate([8], "uint8", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([16], "int8", scope="local") + MultiB = T.decl_buffer([8], "uint8", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -1132,9 +1133,9 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([32], "int4", scope="local") - MultiB = T.allocate([16], "int4", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([32], "int4", scope="local") + MultiB = T.decl_buffer([16], "int4", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -1206,9 +1207,9 
@@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([32], "int4", scope="local") - MultiB = T.allocate([16], "uint4", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([32], "int4", scope="local") + MultiB = T.decl_buffer([16], "uint4", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) @@ -1280,9 +1281,9 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - MultiA = T.allocate([128], "int1", scope="local") - MultiB = T.allocate([64], "int1", scope="local") - Accum = T.allocate([4], "int32", scope="local") + MultiA = T.decl_buffer([128], "int1", scope="local") + MultiB = T.decl_buffer([64], "int1", scope="local") + Accum = T.decl_buffer([4], "int32", scope="local") for i in range(4): Accum[i] = T.int32(0) diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index 321cd28ff6f7..24170b4898f9 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -52,10 +52,10 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - multi_a = T.allocate([4], "float16", scope="local") - multi_b = T.allocate([4], "float16", scope="local") - accum = T.allocate([4], "float16", scope="local") - meta_local = T.allocate([1], "uint32", scope="local") + multi_a = T.decl_buffer([4], "float16", scope="local") + multi_b = T.decl_buffer([4], "float16", scope="local") + accum = T.decl_buffer([4], "float16", scope="local") + meta_local = T.decl_buffer([1], "uint32", scope="local") for i in range(4): accum[i] = T.float16(0) @@ -106,10 +106,10 @@ def 
mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - multi_a = T.allocate([4], "float16", scope="local") - multi_b = T.allocate([4], "float16", scope="local") - accum = T.allocate([4], "float32", scope="local") - meta_local = T.allocate([1], "uint32", scope="local") + multi_a = T.decl_buffer([4], "float16", scope="local") + multi_b = T.decl_buffer([4], "float16", scope="local") + accum = T.decl_buffer([4], "float32", scope="local") + meta_local = T.decl_buffer([1], "uint32", scope="local") for i in range(4): accum[i] = T.float16(0) @@ -160,10 +160,10 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - multi_a = T.allocate([8], "float16", scope="local") - multi_b = T.allocate([8], "float16", scope="local") - accum = T.allocate([4], "float16", scope="local") - meta_local = T.allocate([1], "uint32", scope="local") + multi_a = T.decl_buffer([8], "float16", scope="local") + multi_b = T.decl_buffer([8], "float16", scope="local") + accum = T.decl_buffer([4], "float16", scope="local") + meta_local = T.decl_buffer([1], "uint32", scope="local") for i in range(4): accum[i] = T.float16(0) @@ -214,10 +214,10 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.launch_thread(brow, 1) T.launch_thread(bcol, 1) T.launch_thread(tx, 32) - multi_a = T.allocate([8], "float16", scope="local") - multi_b = T.allocate([8], "float16", scope="local") - accum = T.allocate([4], "float32", scope="local") - meta_local = T.allocate([1], "uint32", scope="local") + multi_a = T.decl_buffer([8], "float16", scope="local") + multi_b = T.decl_buffer([8], "float16", scope="local") + accum = T.decl_buffer([4], "float32", scope="local") + meta_local = T.decl_buffer([1], "uint32", scope="local") for i in range(4): accum[i] = T.float16(0) diff --git 
a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py index 36cc52c16935..28b440a608dc 100644 --- a/tests/python/unittest/test_tir_renew_defs.py +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -135,7 +135,8 @@ def test_undefined_buffer(): @T.prim_func def access_alloc(): # Buffer A should be remapped - A = T.allocate([128], "float16", "global") + A_data = T.allocate([128], "float16", "global") + A = T.buffer_decl(shape=[128], dtype="float16", data=A_data) # check if buffer var also get remapped T.evaluate(A.data) for i in range(128): diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index d5feb21f0db7..4bb13ed77ad8 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -234,7 +234,7 @@ def test_buffer_storage_scope(): buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") - buffer_global = tvm.tir.decl_buffer((10, 10), "float32", scope="global") + buffer_global = tvm.tir.decl_buffer((10, 10), "float32") buffer_empty = tvm.tir.decl_buffer((10, 10), "float32", scope="") func0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_0}) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index 1a3afdd4c1e2..e08f04fa1f25 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -31,13 +31,13 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. 
placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_3 = T.allocate([150528], "int16", "global") + PaddedInput_3 = T.decl_buffer([150528], "int16") for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): - Conv2dOutput_3 = T.allocate([1], "int32", "global") + Conv2dOutput_3 = T.decl_buffer([1], "int32") Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))) diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index 82f4f6515c09..5de06e38a557 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -27,7 +27,8 @@ class Module4: def constant1(a: T.handle) -> None: A = T.match_buffer(a, (10), "int32") B = T.alloc_buffer((10), "int32") - K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K = T.buffer_decl(shape=(10), dtype="int32", data=K_data) for x in T.serial(0, 10): B[x] = A[x] + K[x] @@ -35,7 +36,8 @@ def constant1(a: T.handle) -> None: def constant2(a: T.handle) -> None: A = T.match_buffer(a, (10), "int32") B = T.alloc_buffer((10), "int32") - K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K = T.buffer_decl(shape=(10), dtype="int32", 
data=K_data) for x in T.serial(0, 10): B[x] = A[x] + K[x] @@ -43,7 +45,8 @@ def constant2(a: T.handle) -> None: def constant3(a: T.handle) -> None: A = T.match_buffer(a, (10), "int32") B = T.alloc_buffer((10), "int32") - K = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K_data = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K = T.buffer_decl(shape=(10), dtype="int32", data=K_data) for x in T.serial(0, 10): B[x] = A[x] + K[x] diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index a1195a9d2a65..4cdf71889eee 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -33,7 +33,8 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new = T.allocate([1, 16], "float32", "global") + B_new_data = T.allocate([1, 16], "float32", "global") + B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data) for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -47,7 +48,8 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new = T.allocate([16], "float32", "global") + B_new_data = T.allocate([16], "float32", "global") + B_new = T.buffer_decl(shape=[16], dtype="float32", data=B_new_data) for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -66,7 +68,8 @@ def gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([1, 16], "float32", "local") + B_data = T.allocate([1, 16], "float32", "local") + B = T.buffer_decl(shape=[1, 16], dtype="float32", 
data=B_data, scope="local") for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -87,7 +90,8 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") + B_data = T.allocate([16], "float32", "local") + B = T.buffer_decl(shape=[16], dtype="float32", data=B_data, scope="local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -100,7 +104,8 @@ def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = T.allocate([m], "float32", "global") + B_data = T.allocate([m], "float32", "global") + B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -115,7 +120,8 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): - B = T.allocate([m], "float32", "global") + B_data = T.allocate([m], "float32", "global") + B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -128,8 +134,10 @@ def multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (4, 32), "float32") for i, j in T.grid(4, 32): - B = T.allocate((4, 32), "float32", scope="global") - C = T.allocate((4, 32), "float32", scope="global") + B_data = T.allocate((4, 32), "float32", scope="global") + B = T.buffer_decl(shape=(4, 32), dtype="float32", data=B_data) + C_data = T.allocate((4, 32), "float32", scope="global") + C = T.buffer_decl(shape=(4, 32), dtype="float32", data=C_data) B[i, j] = A[i, j] + 1.0 C[i, j] = A[i, j] + B[i, j] D[i, j] = C[i, j] * 2.0 @@ -143,8 +151,10 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: T.preflattened_buffer(D, (4, 32), 
"float32", data=D.data) for i, j in T.grid(4, 32): - B = T.allocate([128], "float32", "global") - C = T.allocate([128], "float32", "global") + B_data = T.allocate([128], "float32", "global") + B = T.buffer_decl(shape=[128], dtype="float32", data=B_data) + C_data = T.allocate([128], "float32", "global") + C = T.buffer_decl(shape=[128], dtype="float32", data=C_data) B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -155,7 +165,8 @@ def strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in T.serial(4): - B = T.allocate([4, 17], "float32", "global") + B_data = T.allocate([4, 17], "float32", "global") + B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data) B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) for i1, j in T.grid(4, 16): B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 @@ -170,7 +181,8 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): - B_new = T.allocate([68], "float32", "global") + B_new_data = T.allocate([68], "float32", "global") + B_new = T.buffer_decl(shape=[68], dtype="float32", data=B_new_data) for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index b96afb6a0941..548f3bc8d1d2 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -145,12 +145,14 @@ def test_vthread_simplified(): def before_func(): vthread = T.env_thread("vthread") T.launch_thread(vthread, 4) - B = T.allocate([4], "int32", 
"shared") + B_data = T.allocate([4], "int32", scope="shared") + B = T.buffer_decl([4], "int32", data=B_data, scope="shared") B[0:4] = T.broadcast(vthread, 4) @T.prim_func def expected_func(): - B = T.allocate([16], "int32", "shared") + B_data = T.allocate([16], "int32", scope="shared") + B = T.buffer_decl([16], "int32", data=B_data, scope="shared") # The indices for B should each be a single Ramp node, and # should not be the sum of a Ramp and Broadcast node. B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4) @@ -172,12 +174,14 @@ def test_vthread_vectorized(): def before_func(): vthread = T.env_thread("vthread") T.launch_thread(vthread, 4) - B = T.allocate([4], "int32", "shared") + B_data = T.allocate([4], "int32", "shared") + B = T.buffer_decl([4], "int32", data=B_data, scope="shared") B[0:4] = T.broadcast(vthread, 4) @T.prim_func def expected_func(): - B = T.allocate([4], "int32x4", "shared") + B_data = T.allocate([4], "int32x4", "shared") + B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared") B[0 * 4 / 4] = T.broadcast(0, 4) B[1 * 4 / 4] = T.broadcast(1, 4) B[2 * 4 / 4] = T.broadcast(2, 4) diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index 6f557ba09d43..f8f3e3a5aced 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -54,7 +54,8 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new = T.allocate([1, 16], "float32", "global") + B_new_data = T.allocate([1, 16], "float32", "global") + B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data) for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -96,7 +97,8 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) 
T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([1, 16], "float32", "local") + B_data = T.allocate([1, 16], "float32", "local") + B = T.buffer_decl(shape=[1, 16], dtype="float32", scope="local", data=B_data) for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -131,7 +133,8 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = T.allocate([m], "float32", "global") + B_data = T.allocate([m], "float32", "global") + B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -204,8 +207,10 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - B = T.allocate((32,), "float32", "global") - C = T.allocate((32,), "float32", "global") + B_data = T.allocate((32,), "float32", "global") + B = T.buffer_decl(shape=(32,), dtype="float32", data=B_data) + C_data = T.allocate((32,), "float32", "global") + C = T.buffer_decl(shape=(32,), dtype="float32", data=C_data) B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 @@ -240,7 +245,8 @@ def transformed_strided_buffer_func( ) -> None: # body for i0 in T.serial(4): - B = T.allocate([4, 17], "float32", "global") + B_data = T.allocate([4, 17], "float32", "global") + B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data) B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) for i1, j in T.grid(4, 16): B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index fd08f7e2249a..bfa132d4cecf 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -36,9 +36,9 @@ def 
main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo blockIdx_x = T.env_thread("blockIdx.x") # body T.launch_thread(blockIdx_x, 64) - conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") - PadInput_shared = T.allocate([768], "float32", "shared") - weight_shared = T.allocate([4096], "float32", "shared") + conv2d_transpose_nhwc_local = T.decl_buffer([8], "float32", scope="local") + PadInput_shared = T.decl_buffer([768], "float32", scope="shared") + weight_shared = T.decl_buffer([4096], "float32", scope="shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) @@ -67,9 +67,9 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo blockIdx_x = T.env_thread("blockIdx.x") # body T.launch_thread(blockIdx_x, 64) - conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") - PadInput_shared = T.allocate([768], "float32", "shared") - weight_shared = T.allocate([4096], "float32", "shared") + conv2d_transpose_nhwc_local = T.decl_buffer([8], "float32", scope="local") + PadInput_shared = T.decl_buffer([768], "float32", scope="shared") + weight_shared = T.decl_buffer([4096], "float32", scope="shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) @@ -98,9 +98,9 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # body T.launch_thread(blockIdx_x, 64) - conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") - PadInput_shared = T.allocate([768], "float32", "shared") - weight_shared = T.allocate([4096], "float32", "shared") + conv2d_transpose_nhwc_local = T.decl_buffer([8], "float32", scope="local") + 
PadInput_shared = T.decl_buffer([768], "float32", scope="shared") + weight_shared = T.decl_buffer([4096], "float32", scope="shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index ff59f10c0168..95e2eaed55fa 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -95,7 +95,7 @@ def main(A_param: T.handle, C_param: T.handle): threadIdx_x = T.env_thread("threadIdx.x") T.launch_thread(threadIdx_x, 1) for i in T.serial(0, 100): - B = T.allocate([4], "float32", scope="shared") + B = T.decl_buffer([4], "float32", scope="shared") with T.attr(B.data, "double_buffer_scope", 1): for j in T.serial(0, 4): B[j] = A[4 * i + j] @@ -142,7 +142,7 @@ def main(): A_data: T.Ptr[T.int32] = T.call_extern("dummy_extern_function", dtype="handle") # and a buffer is backed by that pointer, - A = T.buffer_decl([1], dtype="float32", data=A_data) + A = T.decl_buffer([1], dtype="float32", data=A_data) T.evaluate(A[0]) # then the call to StorageFlatten would result in an exception diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index b7cb75594997..581afef88942 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -654,14 +654,16 @@ def test_access_in_let_value(): @T.prim_func def func(A: T.Buffer[(8,), "float32"]): for i in range(8): - B = T.allocate((1,), "float32", "global") + B_data = T.allocate((1,), "float32", "global") + B = T.buffer_decl(shape=[1], dtype="float32", data=B_data) B[0] = 3.14 x: T.float32 = T.exp(B[0], dtype="float32") A[i] = (x + 1.0) / (x - 1.0) 
@T.prim_func def func_rewritten(A: T.Buffer[(8,), "float32"]) -> None: - B = T.allocate((1,), "float32", "global") + B_data = T.allocate((1,), "float32", "global") + B = T.buffer_decl(shape=[1], dtype="float32", data=B_data) for i in range(8): B[0] = 3.14 x: T.float32 = T.exp(B[0], dtype="float32") diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 6dba694e45ac..3a638ba45122 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -117,16 +117,19 @@ class before: @T.prim_func def main(): for i in T.unroll(2): - with T.allocate([16], "float32", "global") as buf: + with T.allocate([16], "float32", "global") as buf_data: + buf = T.buffer_decl(shape=[16], dtype="float32", data=buf_data) buf[0] = 0.0 @tvm.script.ir_module class expected: @T.prim_func def main(): - with T.allocate([16], "float32", "global") as buf1: + with T.allocate([16], "float32", "global") as buf1_data: + buf1 = T.buffer_decl(shape=[16], dtype="float32", data=buf1_data) buf1[0] = 0.0 - with T.allocate([16], "float32", "global") as buf2: + with T.allocate([16], "float32", "global") as buf2_data: + buf2 = T.buffer_decl(shape=[16], dtype="float32", data=buf2_data) buf2[0] = 0.0 after = tvm.tir.transform.UnrollLoop()(before) diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index f67148189d8c..265e6fe5d5d5 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -316,12 +316,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_7 = T.allocate([157323], "int16", "global") + 
PaddedInput_7 = T.decl_buffer([157323], "int16") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7 = T.allocate([64], "int32", "global") + Conv2dOutput_7 = T.decl_buffer([64], "int32") for ff_3 in T.serial(0, 64): Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -336,7 +336,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", "global") + tensor_2 = T.decl_buffer([200704], "uint8") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -356,9 +356,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -436,11 +436,11 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16") # body - PaddedInput_1 = T.allocate([379456], "int16", "global") + PaddedInput_1 = T.decl_buffer([379456], "int16") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1 = T.allocate([64], "int32", "global") + Conv2dOutput_1 = T.decl_buffer([64], "int32") for ff_1 in T.serial(0, 64): Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): @@ -457,11 +457,11 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32") # body - PaddedInput_2 = T.allocate([360000], "int16", "global") + PaddedInput_2 = T.decl_buffer([360000], "int16") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2 = T.allocate([64], "int32", "global") + Conv2dOutput_2 = T.decl_buffer([64], "int32") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): Conv2dOutput_2[ff_2] = 0 @@ -480,11 +480,11 @@ def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8") # body - PaddedInput_3 = T.allocate([360000], "int16", "global") + PaddedInput_3 = T.decl_buffer([360000], "int16") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3 = T.allocate([64], "int32", "global") + Conv2dOutput_3 = T.decl_buffer([64], "int32") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): Conv2dOutput_3[ff_3] = 0 @@ -504,11 +504,11 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32")) - 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: @@ -519,11 +519,11 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16") # body - PaddedInput = T.allocate([360000], "int16", "global") + PaddedInput = T.decl_buffer([360000], "int16") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): PaddedInput[i0_i1_fused * 4800 + i2 * 64 + 
i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput = T.allocate([64], "int32", "global") + Conv2dOutput = T.decl_buffer([64], "int32") for ff in T.serial(0, 64): Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 60360ecade70..52880e40cbee 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -128,12 +128,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_7 = T.allocate([157323], "int16", "global") + PaddedInput_7 = T.decl_buffer([157323], "int16") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7 = T.allocate([64], "int32", "global") + Conv2dOutput_7 = T.decl_buffer([64], "int32") for ff_3 in T.serial(0, 64): Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -148,7 +148,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", 
"global") + tensor_2 = T.decl_buffer([200704], "uint8") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -168,9 +168,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -220,14 +220,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_8 = T.allocate([215296], "int16", "global") + PaddedInput_8 = T.decl_buffer([215296], "int16") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + 
(i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): - dummy_allocate = T.allocate([1], "int32", "global") + dummy_allocate = T.decl_buffer([1], "int32") for ax3_outer_4 in T.serial(0, 3): - Conv2dOutput_8 = T.allocate([64], "int32", "global") + Conv2dOutput_8 = T.decl_buffer([64], "int32") for ff_4 in T.serial(0, 64): Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): @@ -261,14 +261,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_8 = T.allocate([215296], "int16", "global") + PaddedInput_8 = T.decl_buffer([215296], "int16") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): - dummy_allocate = T.allocate([1], "int32", "global") + dummy_allocate = T.decl_buffer([1], "int32") for ax3_outer_4 in T.serial(0, 3): - Conv2dOutput_8 = T.allocate([64], "int32", "global") + Conv2dOutput_8 = T.decl_buffer([64], "int32") for ff_4 in T.serial(0, 64): Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): @@ -394,12 +394,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_21 = T.match_buffer(placeholder_18, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_3 = T.match_buffer(T_cast_2, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput = T.allocate([200704], "int16", "global") + PaddedInput = 
T.decl_buffer([200704], "int16") for i0_i1_fused in T.serial(0, 56): for i2, i3 in T.grid(56, 64): PaddedInput[(((i0_i1_fused*3584) + (i2*64)) + i3)] = placeholder_19[(((i0_i1_fused*3584) + (i2*64)) + i3)] for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): - Conv2dOutput = T.allocate([64], "int32", "global") + Conv2dOutput = T.decl_buffer([64], "int32") for ff in T.serial(0, 64): Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): @@ -416,12 +416,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_27 = T.match_buffer(placeholder_24, [96], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_5 = T.match_buffer(T_cast_4, [153], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_1 = T.allocate([150528], "int16", "global") + PaddedInput_1 = T.decl_buffer([150528], "int16") for i0_i1_fused_1 in T.serial(0, 28): for i2_1, i3_1 in T.grid(28, 192): PaddedInput_1[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] = placeholder_25[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): - Conv2dOutput_1 = T.allocate([1], "int32", "global") + Conv2dOutput_1 = T.decl_buffer([1], "int32") for ax3_1 in T.serial(0, 96): Conv2dOutput_1[0] = 0 for rc_1 in T.serial(0, 192): @@ -435,7 +435,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", "global") + tensor_2 = T.decl_buffer([200704], "uint8") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -455,12 +455,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placehol placeholder_35 = T.match_buffer(placeholder_32, [64], dtype="int32", 
elem_offset=0, align=64, offset_factor=1) T_cast_9 = T.match_buffer(T_cast_8, [121], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_2 = T.allocate([150528], "int16", "global") + PaddedInput_2 = T.decl_buffer([150528], "int16") for i0_i1_fused_2 in T.serial(0, 28): for i2_2, i3_2 in T.grid(28, 192): PaddedInput_2[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] = placeholder_33[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): - Conv2dOutput_2 = T.allocate([64], "int32", "global") + Conv2dOutput_2 = T.decl_buffer([64], "int32") for ff_1 in T.serial(0, 64): Conv2dOutput_2[ff_1] = 0 for rc_2 in T.serial(0, 192): @@ -475,7 +475,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_1 placeholder_37 = T.match_buffer(placeholder_36, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_11 = T.match_buffer(T_cast_10, [249], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_3 = T.allocate([150528], "uint8", "global") + tensor_3 = T.decl_buffer([150528], "uint8") for ax0_ax1_fused_6 in T.serial(0, 28): for ax2_6 in T.serial(0, 28): for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): @@ -495,12 +495,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed placeholder_43 = T.match_buffer(placeholder_40, [32], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_13 = T.match_buffer(T_cast_12, [89], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_3 = T.allocate([150528], "int16", "global") + PaddedInput_3 = T.decl_buffer([150528], "int16") for i0_i1_fused_3 in T.serial(0, 28): for i2_3, i3_3 in T.grid(28, 192): PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] = placeholder_41[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): - Conv2dOutput_3 = T.allocate([1], "int32", "global") + Conv2dOutput_3 = 
T.decl_buffer([1], "int32") for ax3_5 in T.serial(0, 32): Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): @@ -516,12 +516,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(pla placeholder_49 = T.match_buffer(placeholder_46, [16], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_15 = T.match_buffer(T_cast_14, [73], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_4 = T.allocate([150528], "int16", "global") + PaddedInput_4 = T.decl_buffer([150528], "int16") for i0_i1_fused_4 in T.serial(0, 28): for i2_4, i3_4 in T.grid(28, 192): PaddedInput_4[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] = placeholder_47[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): - Conv2dOutput_4 = T.allocate([1], "int32", "global") + Conv2dOutput_4 = T.decl_buffer([1], "int32") for ax3_6 in T.serial(0, 16): Conv2dOutput_4[0] = 0 for rc_4 in T.serial(0, 192): @@ -537,12 +537,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed placeholder_55 = T.match_buffer(placeholder_52, [32], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_17 = T.match_buffer(T_cast_16, [89], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_5 = T.allocate([14400], "int16", "global") + PaddedInput_5 = T.decl_buffer([14400], "int16") for i0_i1_fused_5 in T.serial(0, 30): for i2_5, i3_5 in T.grid(30, 16): PaddedInput_5[(((i0_i1_fused_5*480) + (i2_5*16)) + i3_5)] = T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): - Conv2dOutput_5 = T.allocate([1], "int32", "global") + Conv2dOutput_5 = T.decl_buffer([1], "int32") for ax3_7 in T.serial(0, 32): Conv2dOutput_5[0] = 0 for ry, rx, rc_5 in T.grid(3, 3, 16): @@ -558,12 +558,12 @@ def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed placeholder_61 = T.match_buffer(placeholder_58, [128], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_19 = T.match_buffer(T_cast_18, [185], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_6 = T.allocate([86400], "int16", "global") + PaddedInput_6 = T.decl_buffer([86400], "int16") for i0_i1_fused_6 in T.serial(0, 30): for i2_6, i3_6 in T.grid(30, 96): PaddedInput_6[(((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6)] = T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): - Conv2dOutput_6 = T.allocate([64], "int32", "global") + Conv2dOutput_6 = T.decl_buffer([64], "int32") for ax3_outer_3 in T.serial(0, 2): for ff_2 in T.serial(0, 64): Conv2dOutput_6[ff_2] = 0 @@ -581,12 +581,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_7 = T.allocate([157323], "int16", "global") + PaddedInput_7 = T.decl_buffer([157323], "int16") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7 = T.allocate([64], "int32", "global") + Conv2dOutput_7 = T.decl_buffer([64], "int32") for ff_3 in T.serial(0, 64): Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ 
-603,12 +603,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_8 = T.allocate([215296], "int16", "global") + PaddedInput_8 = T.decl_buffer([215296], "int16") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): - Conv2dOutput_8 = T.allocate([64], "int32", "global") + Conv2dOutput_8 = T.decl_buffer([64], "int32") for ax3_outer_4 in T.serial(0, 3): for ff_4 in T.serial(0, 64): Conv2dOutput_8[ff_4] = 0 @@ -638,21 +638,21 @@ def run_model(input: T.handle, output: T.handle) -> None: sid_25 = T.allocate([25088], "int8", "global") sid_26 = T.allocate([25088], "int8", "global") sid_31 = T.allocate([25088], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, sid_7.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", 
sid_6.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5.data, sid_4.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4.data, sid_3.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3.data, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20.data, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3.data, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26.data, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4.data, sid_32.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32.data, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2.data, sid_19.data, sid_25.data, 
sid_31.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5, sid_4, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4, sid_3, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26, dtype="int32")) + 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1129,11 +1129,11 @@ def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeh placeholder_5 = T.match_buffer(placeholder_3, [81], dtype="float32") conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [41], dtype="float32") # body - data_pad = T.allocate([1092], "float32", "global") + data_pad = T.decl_buffer([1092], "float32") for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): data_pad[i0_i1_fused_i2_fused * 42 + i3 * 3 + i4] = T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32") for n_oc_chunk_fused_oh_fused in T.serial(0, 24): - conv2d_NCHWc_global = T.allocate([36], "float32", "global") + conv2d_NCHWc_global = T.decl_buffer([36], "float32") for oc_block_c_init in T.serial(0, 3): conv2d_NCHWc_global[oc_block_c_init] = T.float32(0) for oc_block_c_init in T.serial(0, 3): @@ -1198,15 +1198,15 @@ def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle T_add_1 = T.match_buffer(T_add, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([12], "float32", "global") - with 
T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_norm = T.decl_buffer([12], "float32") + with T.decl_buffer([1], "float32") as T_softmax_maxelem: T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_11[ax0_ax1_fused_ax2_fused * 12 + k]) - T_softmax_exp = T.allocate([12], "float32", "global") + T_softmax_exp = T.decl_buffer([12], "float32") for i3 in T.serial(0, 12): T_softmax_exp[i3] = T.exp(placeholder_11[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") - T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum = T.decl_buffer([1], "float32") T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] @@ -1224,8 +1224,8 @@ def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, T_relu_1 = T.match_buffer(T_relu, [864], dtype="float32") # body for ax1_outer_ax0_outer_fused in T.serial(0, 18): - compute = T.allocate([48], "float32", "global") - with T.allocate([48], "float32", "global") as compute_global: + compute = T.decl_buffer([48], "float32") + with T.decl_buffer([48], "float32") as compute_global: for x_c_init in T.serial(0, 6): compute_global[x_c_init] = T.float32(0) for x_c_init in T.serial(0, 6): @@ -1317,15 +1317,15 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27 T_add_3 = T.match_buffer(T_add_2, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([12], "float32", "global") - with T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_norm = T.decl_buffer([12], "float32") + with T.decl_buffer([1], "float32") as T_softmax_maxelem: T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_28[ax0_ax1_fused_ax2_fused * 12 + 
k]) - T_softmax_exp = T.allocate([12], "float32", "global") + T_softmax_exp = T.decl_buffer([12], "float32") for i3 in T.serial(0, 12): T_softmax_exp[i3] = T.exp(placeholder_28[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") - T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum = T.decl_buffer([1], "float32") T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] @@ -1359,20 +1359,20 @@ def run_model(data: T.handle, output: T.handle) -> None: sid_22 = T.allocate_const([1], "int8", [1]) sid_23 = T.allocate_const([2,1], "int8", [3456]) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_23.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8.data, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7.data, sid_6.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12.data, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11.data, sid_10.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6.data, sid_10.data, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5.data, sid_4.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4.data, 
T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3.data, sid_2.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5.data, sid_20.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20.data, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19.data, sid_18.data, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2.data, sid_18.data, output_buffer.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_23, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7, sid_6, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11, sid_10, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6, sid_10, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5, sid_4, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4, 
T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3, sid_2, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5, sid_20, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19, sid_18, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2, sid_18, output_buffer.data, dtype="int32")) # fmt: on diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index e6d123118757..fdda400a779f 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -98,12 +98,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_7 = T.allocate([157323], "int16", "global") + PaddedInput_7_data = T.allocate([157323], "int16", "global") + PaddedInput_7 = T.buffer_decl(shape=[157323], dtype="int16", data=PaddedInput_7_data) for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - 
Conv2dOutput_7 = T.allocate([64], "int32", "global") + Conv2dOutput_7_data = T.allocate([64], "int32", "global") + Conv2dOutput_7 = T.buffer_decl(shape=[64], dtype="int32", data=Conv2dOutput_7_data) for ff_3 in T.serial(0, 64): Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -120,7 +122,8 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", "global") + tensor_2_data = T.allocate([200704], "uint8", "global") + tensor_2 = T.buffer_decl(shape=[200704], dtype="uint8", data=tensor_2_data) for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -140,9 +143,9 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) # fmt: 
on @@ -299,11 +302,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body - PaddedInput_1 = T.allocate([379456], "int16", "global") + PaddedInput_1_data = T.allocate([379456], "int16", "global") + PaddedInput_1 = T.buffer_decl(shape=[379456], dtype="int16", data=PaddedInput_1_data) for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1 = T.allocate([64], "int32", "global") + Conv2dOutput_1_data = T.allocate([64], "int32", "global") + Conv2dOutput_1 = T.buffer_decl(shape=[64], dtype="int32", data=Conv2dOutput_1_data) for ff_1 in T.serial(0, 64): Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): @@ -324,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1 = T.match_buffer(T_add, [407], dtype="int32") T.preflattened_buffer(T_add_1, [407], dtype="int32") # body - PaddedInput_2 = T.allocate([360000], "int16", "global") + PaddedInput_2_data = T.allocate([360000], "int16", "global") + PaddedInput_2 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_2_data) for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2 = T.allocate([64], "int32", "global") + Conv2dOutput_2_data = T.allocate([64], "int32", "global") + Conv2dOutput_2 = T.buffer_decl(shape=[64], dtype="int32", data=Conv2dOutput_2_data) for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): Conv2dOutput_2[ff_2] = 0 
@@ -352,11 +359,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body - PaddedInput_3 = T.allocate([360000], "int16", "global") + PaddedInput_3_data = T.allocate([360000], "int16", "global") + PaddedInput_3 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_3_data) for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3 = T.allocate([64], "int32", "global") + Conv2dOutput_3_data = T.allocate([64], "int32", "global") + Conv2dOutput_3 = T.buffer_decl(shape=[64], dtype="int32", data=Conv2dOutput_3_data) for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): Conv2dOutput_3[ff_3] = 0 @@ -376,11 +385,11 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", 
dtype="handle"), sid_6.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: @@ -395,11 +404,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body - PaddedInput = T.allocate([360000], "int16", "global") + PaddedInput_data = T.allocate([360000], "int16", "global") + PaddedInput = T.buffer_decl([360000], "int16", 
data=PaddedInput_data) for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput = T.allocate([64], "int32", "global") + Conv2dOutput_data = T.allocate([64], "int32", "global") + Conv2dOutput = T.buffer_decl([64], "int32", data=Conv2dOutput_data) for ff in T.serial(0, 64): Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 155ff0962def..756b97b0d223 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -48,12 +48,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) # body - PaddedInput_7 = T.allocate([157323], "int16", "global") + PaddedInput_7 = T.decl_buffer([157323], "int16") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7 = T.allocate([64], "int32", "global") + Conv2dOutput_7 = T.decl_buffer([64], "int32") for ff_3 in T.serial(0, 64): Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -68,7 +68,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], 
dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", "global") + tensor_2 = T.decl_buffer([200704], "uint8") for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -88,9 +88,9 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 45ea88f829ec..17622789558d 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -94,12 +94,18 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64, offset_factor=1) C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=64, offset_factor=1) # body - packedB = T.allocate([32768], "float32", "global") + packedB_data = 
T.allocate([32768], "float32", "global") + packedB = T.buffer_decl( + shape=[32768], dtype="float32", scope="global", data=packedB_data + ) for x in T.parallel(0, 32): for y in T.serial(0, 1024): packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[y, T.ramp(x * 32, 1, 32)] for x_outer in T.parallel(0, 32): - C_global = T.allocate([1024], "float32", "global") + C_global_data = T.allocate([1024], "float32", "global") + C_global = T.buffer_decl( + shape=[1024], dtype="float32", scope="global", data=C_global_data + ) for y_outer in T.serial(0, 32): for x_c_init in T.serial(0, 32): C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) @@ -953,11 +959,24 @@ def func( ty = T.env_thread("threadIdx.y") tz = T.env_thread("threadIdx.z") T.launch_thread(bz, 196) - Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") - Apad_shared = T.allocate([12288], "float16", "shared") - W_shared = T.allocate([12288], "float16", "shared") - Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") - W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") + Conv_wmma_accumulator_data = T.allocate([2048], "float32", "wmma.accumulator") + Conv_wmma_accumulator = T.buffer_decl( + shape=[2048], dtype="float32", scope="wmma.accumulator", data=Conv_wmma_accumulator_data + ) + Apad_shared_data = T.allocate([12288], "float16", "shared") + Apad_shared = T.buffer_decl( + shape=[12288], dtype="float16", scope="shared", data=Apad_shared_data + ) + W_shared_data = T.allocate([12288], "float16", "shared") + W_shared = T.buffer_decl(shape=[12288], dtype="float16", scope="shared", data=W_shared_data) + Apad_shared_wmma_matrix_a_data = T.allocate([512], "float16", "wmma.matrix_a") + Apad_shared_wmma_matrix_a = T.buffer_decl( + shape=[512], dtype="float16", scope="wmma.matrix_a", data=Apad_shared_wmma_matrix_a_data + ) + W_shared_wmma_matrix_b_data = T.allocate([1024], "float16", "wmma.matrix_b") + W_shared_wmma_matrix_b = 
T.buffer_decl( + shape=[1024], dtype="float16", scope="wmma.matrix_b", data=W_shared_wmma_matrix_b_data + ) T.launch_thread(bx, 2) T.launch_thread(by, 4) T.launch_thread(ty, 4) @@ -2479,7 +2498,8 @@ def vthread_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") + B_data = T.allocate([16], "float32", "local") + B = T.buffer_decl(shape=[16], dtype="float32", scope="local", data=B_data) for j in range(16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) for j in range(16): @@ -2792,11 +2812,13 @@ def B(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (10), "int32") B = T.alloc_buffer((10), "int32") - K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K1_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K1 = T.buffer_decl(shape=[10], dtype="int32", data=K1_data) for x in T.serial(0, 10): B[x] = A[x] + K1[x] - K2 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K2_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K2 = T.buffer_decl(shape=[10], dtype="int32", data=K2_data) for x in T.serial(0, 10): B[x] = B[x] + K2[x] @@ -2812,7 +2834,8 @@ def constant(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (10), "int32") C = T.match_buffer(c, (10), "int32") B = T.alloc_buffer((10), "int32") - K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + K = T.buffer_decl(shape=[10], dtype="int32", data=K_data) for x in T.serial(0, 10): B[x] = A[x] + K[x] @@ -2961,7 +2984,8 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body - tensor_2 = 
T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + tensor_2_data = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + tensor_2 = T.buffer_decl(shape=[200704], dtype="uint8", scope="global", data=tensor_2_data) for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): @@ -2987,7 +3011,8 @@ def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128 * 128], dtype="float32") for i in T.serial(0, 128): T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") + reduce_temp0_data = T.allocate([1], "float32", "local") + reduce_temp0 = T.buffer_decl(shape=[1], dtype="float32", scope="local", data=reduce_temp0_data) with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0.data, threadIdx_x, dtype="handle")) @@ -3002,7 +3027,8 @@ def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128 * 128], dtype="float32") for i in T.serial(0, 128): T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") + reduce_temp0_data = T.allocate([1], "float32", "local") + reduce_temp0 = T.buffer_decl(shape=[1], dtype="float32", scope="local", data=reduce_temp0_data) with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0.data, threadIdx_x, dtype="handle")) @@ -3149,7 +3175,8 @@ def func_T_ptr_let_statement( def func_T_ptr_allocate(): @T.prim_func def func_T_ptr_allocate() -> None: - A = T.allocate([1024], "float32", "global") + A_data = 
T.allocate([1024], "float32", "global") + A = T.buffer_decl(shape=[1024], dtype="float32", scope="global", data=A_data) A[0] = 0.0 return func_T_ptr_allocate @@ -3240,8 +3267,10 @@ def string_annotation_of_special_chars(): def pointer_type(): @T.prim_func def func_with_ptr_type_annotations(x: T.Ptr[T.int32], y: T.Ptr[T.int32, "shared"]): - xx = T.allocate([16], "int32", "global") - yy = T.allocate([16], "int32", "shared") + xx_data = T.allocate([16], "int32", "global") + xx = T.buffer_decl(shape=[16], dtype="int32", scope="global", data=xx_data) + yy_data = T.allocate([16], "int32", "shared") + yy = T.buffer_decl(shape=[16], dtype="int32", scope="shared", data=yy_data) a: T.Ptr[T.int32] = T.address_of(xx[0], dtype="handle") b: T.Ptr[T.int32, "shared"] = T.address_of(yy[0], dtype="handle") T.evaluate(T.call_extern("copy", a, b, dtype="")) @@ -3313,6 +3342,24 @@ def func(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> return func +def allocate_and_decl_buffer(): + @T.prim_func + def func(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None: + D_data = T.allocate((16,), "float32", "global") + D = T.decl_buffer((16,), "float32", data=D_data) + for i in range(4): + with T.allocate((4,), "float32", "global") as C_data: + C = T.decl_buffer((4,), "float32", data=C_data) + for j in range(4): + C[j] = A[i * 4 + j] + T.float32(1.0) + for j in range(4): + D[j] = C[j] + for j in range(4): + B[i * 4 + j] = D[j] + + return func + + def float_infinity(): @T.prim_func def func( @@ -3374,6 +3421,7 @@ def func( let_expression, void_ptr, decl_buffer, + allocate_and_decl_buffer, float_infinity, )