From 3a4906603075e1c5c39e9a80bdd2bc13676c2794 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 14 Apr 2021 05:08:48 +0800 Subject: [PATCH] [TensorIR] [Script] adding support for opaque block (#7829) * change complete tag * add parsing support for opaque block * address and add testcase * address * address --- include/tvm/tir/stmt.h | 9 +++++ python/tvm/script/scope_handler.py | 14 ++++++-- src/tir/ir/script/script_complete.cc | 15 ++++++-- .../unittest/test_tvmscript_complete.py | 28 +++++++++++++-- .../unittest/test_tvmscript_roundtrip.py | 36 +++++++++++++++++++ 5 files changed, 95 insertions(+), 7 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 84c27498740a1..09317680f6396 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1316,6 +1316,15 @@ constexpr const char* fragment_layout = "fragment_layout"; * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted */ constexpr const char* hand_threaded = "hand_threaded"; + +/*! + * \brief Mark whether the script-completer needs to fill in missing access region + * during script parsing. + * \note The result should be an integer mask with range [0, 4). + * if (mask & 1) the read region should be detected, + * if (mask & 2) the write region should be detected. + */ +constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access"; /*! 
* \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index c7d841abc36d5..913ba5a6e3386 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -282,6 +282,13 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): if block_info.writes else [] ) + + region_detect_mask: int = (block_info.reads is None) | ( + (block_info.writes is None) << 1 + ) + annotations = {} if block_info.annotations is None else block_info.annotations + if region_detect_mask != 0: + annotations["tir.script_parsing_detect_access"] = region_detect_mask inner = tvm.tir.Block( block_iters, reads, @@ -291,14 +298,17 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): block_info.init, block_info.alloc_buffers, block_info.match_buffers, - block_info.annotations, + annotations, span, ) # create block var iter binding values: List[PrimExpr] if not block_info.iter_bindings: values = self.context.loop_stack[-2].copy() - if len(values) == 0: + if len(block_iters) == 0: + # It is an opaque block without any bindings + values = [] + elif len(values) == 0: values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters) elif len(values) != len(block_iters): self.context.report_error( diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index a42b5ea5b3a07..c15b3bb47bf4d 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -76,8 +76,15 @@ class ScriptCompleter : public StmtMutator { for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + // Get access detection mask + // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write + int mask = 0; + auto it = op->annotations.find(attr::script_parsing_detect_access); + if (it != op->annotations.end()) { 
+ mask = Downcast((*it).second)->value; + } // ignore root block or blocks which already has reads/writes regions - if (block->reads.empty() || block->writes.empty()) { + if (mask != 0) { if (op->iter_vars.empty()) { // non-root opaque block is not allowed CHECK(is_root_block) @@ -93,8 +100,10 @@ class ScriptCompleter : public StmtMutator { << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " "direct access by buffer data. Please annotation the access region manually"; auto n = CopyOnWrite(block.operator->()); - if (n->reads.empty()) n->reads = reads; - if (n->writes.empty()) n->writes = writes; + if (mask & 1) n->reads = reads; + if (mask & 2) n->writes = writes; + n->annotations = op->annotations; + n->annotations.erase(attr::script_parsing_detect_access); return Block(n); } else { return std::move(block); diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 012ccc4b86285..76d4d9b980436 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -81,6 +81,22 @@ def func_with_opaque_block(a: ty.handle, b: ty.handle, c: ty.handle) -> None: C[vi, vj] = B[vi, vj] + tir.float32(1) +@tvm.script.tir +def func_with_part_access_region(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([]) as []: + with tir.block([128, 128]) as [vi, vj]: + tir.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + tir.float32(1) + + with tir.block([128, 128]) as [vi, vj]: + tir.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + tir.float32(1) + + def test_complete_matmul(): func = matmul A, B, C = [func.buffer_map[x] for x in func.params] @@ -124,8 +140,7 @@ def test_complete_matmul_original(): tvm.ir.assert_structural_equal(block2.writes, [access_C]) -def test_complete_with_root(): - func = elementwise_with_root +def 
_check_elementwise(func): A, B, C = [func.buffer_map[x] for x in func.params] block1 = func.body.block.body[0].body.body.block @@ -154,6 +169,14 @@ def test_complete_with_root(): ) +def test_complete_with_root(): + _check_elementwise(elementwise_with_root) + + +def test_complete_part_region(): + _check_elementwise(func_with_part_access_region) + + def test_complete_opaque_block_error(): def render(e): pass @@ -172,3 +195,4 @@ def render(e): test_complete_matmul_original() test_complete_with_root() test_complete_opaque_block_error() + test_complete_part_region() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bd36b79d7f4eb..cbdcbbb2e6f0a 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2828,6 +2828,41 @@ def test_block_elements(): assert block.annotations["attr_key"] == "attr_value" +@tvm.script.tir +def opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + for i in range(0, 16): + for j in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, j]) + A[i, j] = tir.float32(0) + with tir.block([]): + tir.reads([A[i, 0:16]]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + B[i, j] = A[i, j] + + +def test_opaque_block(): + func = opaque_block + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + root_block = rt_func.body.block + assert isinstance(root_block, tir.stmt.Block) + assert isinstance(root_block.body, tir.stmt.For) + assert isinstance(root_block.body.body[0], tir.stmt.For) + assert isinstance(root_block.body.body[0].body, tir.stmt.BlockRealize) + assert isinstance(root_block.body.body[0].body.block, tir.stmt.Block) + assert len(root_block.body.body[0].body.block.iter_vars) == 0 + assert isinstance(root_block.body.body[1], 
tir.stmt.BlockRealize) + assert isinstance(root_block.body.body[1].block, tir.stmt.Block) + assert len(root_block.body.body[1].block.iter_vars) == 0 + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() @@ -2842,3 +2877,4 @@ def test_block_elements(): test_predicate() test_for_thread_binding() test_block_elements() + test_opaque_block()