Skip to content

Commit

Permalink
[TensorIR] [Script] adding support for opaque block (apache#7829)
Browse files Browse the repository at this point in the history
* change complete tag

* add parsing support for opaque block

* address and add testcase

* address

* address
  • Loading branch information
Hzfengsy authored and Trevor Morris committed May 6, 2021
1 parent 333ef2b commit 3a49066
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 7 deletions.
9 changes: 9 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,15 @@ constexpr const char* fragment_layout = "fragment_layout";
* \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
*/
constexpr const char* hand_threaded = "hand_threaded";

/*!
* \brief Mark whether the script-completer needs to fill in missing access regions
* during script parsing.
* \note The value should be an integer mask in the range [0, 4).
* if (mask & 1) the read region should be detected,
* if (mask & 2) the write region should be detected.
*/
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,13 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
if block_info.writes
else []
)

region_detect_mask: int = (block_info.reads is None) | (
(block_info.writes is None) << 1
)
annotations = {} if block_info.annotations is None else block_info.annotations
if region_detect_mask != 0:
annotations["tir.script_parsing_detect_access"] = region_detect_mask
inner = tvm.tir.Block(
block_iters,
reads,
Expand All @@ -291,14 +298,17 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
block_info.init,
block_info.alloc_buffers,
block_info.match_buffers,
block_info.annotations,
annotations,
span,
)
# create block var iter binding
values: List[PrimExpr]
if not block_info.iter_bindings:
values = self.context.loop_stack[-2].copy()
if len(values) == 0:
if len(block_iters) == 0:
# It is an opaque block without any bindings
values = []
elif len(values) == 0:
values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters)
elif len(values) != len(block_iters):
self.context.report_error(
Expand Down
15 changes: 12 additions & 3 deletions src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,15 @@ class ScriptCompleter : public StmtMutator {
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->erase(alloc_buffer->data);
}
// Get access detection mask
// 0: both regions provided; bit 0 (1, 3): detect reads; bit 1 (2, 3): detect writes
int mask = 0;
auto it = op->annotations.find(attr::script_parsing_detect_access);
if (it != op->annotations.end()) {
mask = Downcast<IntImm>((*it).second)->value;
}
// ignore root block or blocks which already have reads/writes regions
if (block->reads.empty() || block->writes.empty()) {
if (mask != 0) {
if (op->iter_vars.empty()) {
// non-root opaque block is not allowed
CHECK(is_root_block)
Expand All @@ -93,8 +100,10 @@ class ScriptCompleter : public StmtMutator {
<< "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
"direct access by buffer data. Please annotation the access region manually";
auto n = CopyOnWrite(block.operator->());
if (n->reads.empty()) n->reads = reads;
if (n->writes.empty()) n->writes = writes;
if (mask & 1) n->reads = reads;
if (mask & 2) n->writes = writes;
n->annotations = op->annotations;
n->annotations.erase(attr::script_parsing_detect_access);
return Block(n);
} else {
return std::move(block);
Expand Down
28 changes: 26 additions & 2 deletions tests/python/unittest/test_tvmscript_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ def func_with_opaque_block(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
C[vi, vj] = B[vi, vj] + tir.float32(1)


# TVMScript fixture used by test_complete_part_region: an outer block with no iteration
# variables wraps two 128x128 element-wise blocks, each of which declares only one of its
# two access regions (reads for the first, writes for the second).
@tvm.script.tir
def func_with_part_access_region(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])

with tir.block([]) as []:
with tir.block([128, 128]) as [vi, vj]:
# Only the read region is declared here; no tir.writes is given.
tir.reads(A[vi, vj])
B[vi, vj] = A[vi, vj] + tir.float32(1)

with tir.block([128, 128]) as [vi, vj]:
# Only the write region is declared here; no tir.reads is given.
tir.writes(C[vi, vj])
C[vi, vj] = B[vi, vj] + tir.float32(1)


def test_complete_matmul():
func = matmul
A, B, C = [func.buffer_map[x] for x in func.params]
Expand Down Expand Up @@ -124,8 +140,7 @@ def test_complete_matmul_original():
tvm.ir.assert_structural_equal(block2.writes, [access_C])


def test_complete_with_root():
func = elementwise_with_root
def _check_elementwise(func):
A, B, C = [func.buffer_map[x] for x in func.params]

block1 = func.body.block.body[0].body.body.block
Expand Down Expand Up @@ -154,6 +169,14 @@ def test_complete_with_root():
)


def test_complete_with_root():
    """Run the shared element-wise checks on ``elementwise_with_root``."""
    func = elementwise_with_root
    _check_elementwise(func)


def test_complete_part_region():
    """Run the shared element-wise checks on ``func_with_part_access_region``.

    The blocks of that function each declare only one of their read/write regions.
    """
    func = func_with_part_access_region
    _check_elementwise(func)


def test_complete_opaque_block_error():
def render(e):
pass
Expand All @@ -172,3 +195,4 @@ def render(e):
test_complete_matmul_original()
test_complete_with_root()
test_complete_opaque_block_error()
test_complete_part_region()
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,6 +2828,41 @@ def test_block_elements():
assert block.annotations["attr_key"] == "attr_value"


# TVMScript fixture for test_opaque_block: loops over two 16x16 float32 buffers containing
# two blocks that are declared with an empty iteration-variable list (tir.block([])).
@tvm.script.tir
def opaque_block(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
B = tir.match_buffer(b, (16, 16), "float32")

for i in range(0, 16):
for j in range(0, 16):
# First block: declares an empty read region and writes the single element A[i, j].
with tir.block([]):
tir.reads([])
tir.writes(A[i, j])
A[i, j] = tir.float32(0)
# Second block: declares the slices A[i, 0:16] as read and B[i, 0:16] as written.
with tir.block([]):
tir.reads([A[i, 0:16]])
tir.writes([B[i, 0:16]])
for j in range(0, 16):
B[i, j] = A[i, j]


def test_opaque_block():
    """Round-trip ``opaque_block`` through TVMScript text and check the block structure."""
    original = opaque_block
    script = tvm.script.asscript(original, True)
    rt_func = tvm.script.from_source(script)
    tvm.ir.assert_structural_equal(original, rt_func)

    root_block = rt_func.body.block
    assert isinstance(root_block, tir.stmt.Block)
    outer_loop = root_block.body
    assert isinstance(outer_loop, tir.stmt.For)

    # First statement under the outer loop: an inner loop wrapping a block with no iter vars.
    inner_loop = outer_loop.body[0]
    assert isinstance(inner_loop, tir.stmt.For)
    first_realize = inner_loop.body
    assert isinstance(first_realize, tir.stmt.BlockRealize)
    assert isinstance(first_realize.block, tir.stmt.Block)
    assert len(first_realize.block.iter_vars) == 0

    # Second statement under the outer loop: a block realize with no iter vars.
    second_realize = outer_loop.body[1]
    assert isinstance(second_realize, tir.stmt.BlockRealize)
    assert isinstance(second_realize.block, tir.stmt.Block)
    assert len(second_realize.block.iter_vars) == 0


if __name__ == "__main__":
test_opt_gemm_normalize()
test_opt_gemm_mod_host()
Expand All @@ -2842,3 +2877,4 @@ def test_block_elements():
test_predicate()
test_for_thread_binding()
test_block_elements()
test_opaque_block()

0 comments on commit 3a49066

Please sign in to comment.