Skip to content

Commit

Permalink
[TIR] Fix of inter thread reduction with shared memory prefetch (#16406)
Browse files Browse the repository at this point in the history
This is a fix of `LowerCrossThreadReduction`: The pass will remove all the loops with thread bind under the inter thread reduction block, which will introduce some issues when we meet the case where there could be other non-reduction blocks under the reduction thread.

Before removing a thread-bound loop, check if the block(s) under this loop has reduction block var. If the block(s) under have reduction do not have any reduction block var, it means that block is not reduction, and therefore this thread-bound loop should be kept. Otherwise, we remove the thread-bound loop as usual.

related discussion: https://discuss.tvm.apache.org/t/missing-thread-bind-loops-under-block-reduction-when-transformed-with-tir/16232/6
  • Loading branch information
LeiWang1999 authored Jan 21, 2024
1 parent ccca00a commit 81f8690
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,37 @@ class BufferReplacer : private StmtExprMutator {
*/
class InThreadReducerMaker : private StmtMutator {
public:
/*!
* \brief Visitor class to collect all reduction block variables under a loop.
*/
class UnderLoopReductionBlockVarCollector : public StmtVisitor {
public:
/*!
* \brief Check if the given statement has any reduction blocks.
* \param stmt The statement to check.
* \return True if the statement has reduction blocks, false otherwise.
*/
static bool CheckHasReductionBlocks(const Stmt& stmt) {
UnderLoopReductionBlockVarCollector collector;
collector(stmt);
return collector.reduction_block_vars_.size() > 0;
}

private:
void VisitStmt_(const BlockNode* block) final {
Array<IterVar> iter_vars = block->iter_vars;
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type == kCommReduce) {
reduction_block_vars_.push_back(iter_var);
}
}
StmtVisitor::VisitStmt_(block);
}

/*! \brief the map from thread tag to its extent */
Array<IterVar> reduction_block_vars_;
};

static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
Optional<BlockRealize> tgt_realize, Stmt stmt) {
return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
Expand All @@ -220,7 +251,11 @@ class InThreadReducerMaker : private StmtMutator {
if (Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop))) {
For res = opt_res.value();
if (res->thread_binding.defined()) {
return res->body;
UnderLoopReductionBlockVarCollector collector;
if (!res->body.defined() || collector.CheckHasReductionBlocks(res)) {
return res->body;
}
return std::move(res);
} else {
return std::move(res);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,225 @@ def lowered_single_reduction_loop_with_block_predicate(
)


@T.prim_func
def spatial_reduction_with_shared_prefetch(
A: T.Buffer((128, 150528), "float32"),
B: T.Buffer((128, 150528), "float32"),
C: T.Buffer((128, 128), "float32"),
):
C_local = T.alloc_buffer((128, 128), scope="local")
A_shared = T.alloc_buffer((128, 150528), scope="shared")
B_shared = T.alloc_buffer((128, 150528), scope="shared")
for ax0_0_ax1_0_fused in T.thread_binding(256, thread="blockIdx.x"):
for ax0_1_ax1_1_fused in T.thread_binding(64, thread="threadIdx.y"):
for ax2_1_1_fused in T.thread_binding(2, thread="threadIdx.x"):
for ax2_0 in range(392):
for ax0_ax1_fused_0 in range(6):
for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
for ax0_ax1_fused_3 in T.serial(4):
with T.block("A_shared"):
v0 = T.axis.spatial(
128,
ax0_0_ax1_0_fused // 16 * 8
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
// 384,
)
v1 = T.axis.spatial(
150528,
ax2_0 * 384
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
% 384,
)
T.reads(A[v0, v1])
T.writes(A_shared[v0, v1])
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0 in range(6):
for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
for ax0_ax1_fused_3 in T.serial(4):
with T.block("B_shared"):
v0 = T.axis.spatial(
128,
ax0_0_ax1_0_fused % 16 * 8
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
// 384,
)
v1 = T.axis.spatial(
150528,
ax2_0 * 384
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
% 384,
)
T.reads(B[v0, v1])
T.writes(B_shared[v0, v1])
B_shared[v0, v1] = B[v0, v1]
for ax2_1_0 in range(192):
with T.block("B"):
v0 = T.axis.spatial(
128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8
)
v1 = T.axis.spatial(
128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8
)
v2 = T.axis.reduce(150528, ax2_0 * 384 + ax2_1_0 * 2 + ax2_1_1_fused)
T.reads(A_shared[v0, v2], B_shared[v1, v2])
T.writes(C_local[v0, v1])
with T.init():
C_local[v0, v1] = T.float32(0)
C_local[v0, v1] = C_local[v0, v1] + A_shared[v0, v2] * B_shared[v1, v2]
with T.block("C_local"):
v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
T.reads(C_local[v0, v1])
T.writes(C[v0, v1])
C[v0, v1] = C_local[v0, v1]


@T.prim_func
def lowered_spatial_reduction_with_shared_prefetch(
A: T.Buffer((128, 150528), "float32"),
B: T.Buffer((128, 150528), "float32"),
C: T.Buffer((128, 128), "float32"),
):
C_local = T.alloc_buffer((128, 128), scope="local")
A_shared = T.alloc_buffer((128, 150528), scope="shared")
B_shared = T.alloc_buffer((128, 150528), scope="shared")
cross_thread_C_local = T.alloc_buffer((1,), strides=(1,), scope="local")
in_thread_C_local = T.alloc_buffer((1,), strides=(1,), scope="local")
for ax0_0_ax1_0_fused in T.thread_binding(256, thread="blockIdx.x"):
for ax0_1_ax1_1_fused in T.thread_binding(64, thread="threadIdx.y"):
for ax2_1_1_fused in T.thread_binding(2, thread="threadIdx.x"):
with T.block("B_in_thread_init"):
T.reads()
T.writes(in_thread_C_local[0])
in_thread_C_local[0] = T.float32(0)
for ax2_0 in range(392):
for ax0_ax1_fused_0 in range(6):
for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
for ax0_ax1_fused_3 in range(4):
with T.block("A_shared"):
v0 = T.axis.spatial(
128,
ax0_0_ax1_0_fused // 16 * 8
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
// 384,
)
v1 = T.axis.spatial(
150528,
ax2_0 * 384
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
% 384,
)
T.reads(A[v0, v1])
T.writes(A_shared[v0, v1])
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0 in range(6):
for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
for ax0_ax1_fused_3 in range(4):
with T.block("B_shared"):
v0 = T.axis.spatial(
128,
ax0_0_ax1_0_fused % 16 * 8
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
// 384,
)
v1 = T.axis.spatial(
150528,
ax2_0 * 384
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 8
+ ax0_ax1_fused_2 * 4
+ ax0_ax1_fused_3
)
% 384,
)
T.reads(B[v0, v1])
T.writes(B_shared[v0, v1])
B_shared[v0, v1] = B[v0, v1]
for ax2_1_0 in range(192):
with T.block("B_in_thread"):
v0 = T.axis.spatial(
128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8
)
v1 = T.axis.spatial(
128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8
)
v2 = T.axis.reduce(150528, ax2_0 * 384 + ax2_1_0 * 2 + ax2_1_1_fused)
T.reads(A_shared[v0, v2], B_shared[v1, v2])
T.writes(in_thread_C_local[0])
in_thread_C_local[0] = (
in_thread_C_local[0] + A_shared[v0, v2] * B_shared[v1, v2]
)
with T.block("B_cross_thread"):
T.reads(in_thread_C_local[0])
T.writes(cross_thread_C_local[0])
T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
)
T.tvm_thread_allreduce(
T.uint32(1),
in_thread_C_local[0],
T.bool(True),
cross_thread_C_local[0],
ax2_1_1_fused,
)
with T.block("B_write_back"):
v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
T.reads(cross_thread_C_local[0])
T.writes(C_local[v0, v1])
C_local[v0, v1] = cross_thread_C_local[0]
for tx in T.thread_binding(2, thread="threadIdx.x"):
with T.block("C_local"):
v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
T.where(tx == 0)
T.reads(C_local[v0, v1])
T.writes(C[v0, v1])
C[v0, v1] = C_local[v0, v1]


@T.prim_func
def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")):
for i_0 in range(1):
Expand Down Expand Up @@ -1588,6 +1807,13 @@ def test_with_block_predicate():
_check(with_block_predicate, lowered_with_block_predicate)


def test_single_reduction_loop_with_shared_memory_prefetch():
_check(
spatial_reduction_with_shared_prefetch,
lowered_spatial_reduction_with_shared_prefetch,
)


def test_single_reduction_loop_with_block_predicate():
_check(
single_reduction_loop_with_block_predicate,
Expand Down

0 comments on commit 81f8690

Please sign in to comment.