Skip to content

Commit

Permalink
[BugFix][TIR] Fix primitive Bind for init-inside blocks (#9359)
Browse files Browse the repository at this point in the history
* [BugFix][TIR] Fix primitive `Bind` for init-inside blocks

* fix python black error
  • Loading branch information
MasterJH5574 authored Oct 25, 2021
1 parent bdb311b commit aa38997
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/tir/schedule/primitive/for_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind
runtime::ThreadScope thread_scope) {
PreOrderVisit(loop, [&](const ObjectRef& node) {
if (const auto* realize = node.as<BlockRealizeNode>()) {
// If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block
// inside `tir.init()`. We don't check the condition for such blocks.
if (!self->stmt2ref.count(realize->block.get())) {
return false;
}
CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef<BlockRealize>(realize),
thread_scope);
}
Expand Down
46 changes: 46 additions & 0 deletions tests/python/unittest/test_tir_schedule_for_kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,44 @@ def opaque_block(a: T.handle) -> None:
A[i + 1] = A[i + 1] + A[i]


@T.prim_func
def block_inside_init(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], dtype="float32")
B = T.match_buffer(b, [128, 128], dtype="float32")
for i in T.serial(0, 128):
with T.block("outer"):
vi = T.axis.S(128, i)
with T.init():
for j in T.serial(0, 128):
with T.block("init"):
vj = T.axis.S(128, j)
B[vi, vj] = 0.0
for k in T.serial(0, 128):
for j in T.serial(0, 128):
with T.block("inner"):
vj, vk = T.axis.remap("SR", [j, k])
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]


@T.prim_func
def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], dtype="float32")
B = T.match_buffer(b, [128, 128], dtype="float32")
for i in T.thread_binding(0, 128, thread="threadIdx.x"):
with T.block("outer"):
vi = T.axis.S(128, i)
with T.init():
for j in T.serial(0, 128):
with T.block("init"):
vj = T.axis.S(128, j)
B[vi, vj] = 0.0
for k in T.serial(0, 128):
for j in T.serial(0, 128):
with T.block("inner"):
vj, vk = T.axis.remap("SR", [j, k])
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -361,5 +399,13 @@ def test_bind_after_bind():
verify_trace_roundtrip(s, mod=element_wise)


def test_block_inside_init():
s = tir.Schedule(block_inside_init, debug_mask="all")
(i,) = s.get_loops(s.get_block("outer"))
s.bind(i, "threadIdx.x")
tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_block_inside_init)
verify_trace_roundtrip(s, mod=block_inside_init)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit aa38997

Please sign in to comment.