Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Handle 'warp_execution' in RewriteCooperativeFetch #11955

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
return Downcast<BlockRV>(inst->inputs[0]);
}

/*!
* \brief Parse instruction: sch.annotate(..., attr::warp_execution)
* \param sch The schedule
* \param inst The instruction to be parsed
* \return Whether ths parsing is successful
*/
bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_kind_annotate)) {
return false;
}
ICHECK_EQ(inst->inputs.size(), 2);
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::warp_execution;
}

} // namespace tir

namespace meta_schedule {
Expand All @@ -76,14 +93,24 @@ namespace meta_schedule {
class RewriteCooperativeFetchNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
void InitializeWithTuneContext(const TuneContext& context) final {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
this->thread_warp_size_ = v.value()->value;
} else {
TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target";
}
}

// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;

void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);

private:
int thread_warp_size_ = -1;
};

bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
Expand All @@ -101,6 +128,10 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
thread_extent_y = new_thread_extent.value()->value;
continue;
}
if (tir::ParseWarpExecutionAnn(sch, inst)) {
thread_extent_x = thread_warp_size_;
continue;
}
Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
if (!opt_block_rv.defined()) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import tvm
import tvm.testing
from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import RewriteCooperativeFetch
Expand Down Expand Up @@ -99,6 +100,108 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
C[v0, v1] = C_local[v0, v1]


@tvm.script.ir_module
class WarpExecutionAfterRewrite:
@T.prim_func
def main(
A: T.Buffer[(512, 512), "float32"],
B: T.Buffer[(512, 512), "float32"],
C: T.Buffer[(512, 512), "float32"],
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
for i2_0 in T.serial(0, 1):
for ax0_ax1_fused_0 in T.serial(0, 1024):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(
0, 32, thread="threadIdx.x"
):
with T.block("A_shared"):
v0 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 256
+ ax0_ax1_fused_1 * 32
+ ax0_ax1_fused_2
)
// 512,
)
v1 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 256
+ ax0_ax1_fused_1 * 32
+ ax0_ax1_fused_2
)
% 512,
)
T.reads([A[v0, v1]])
T.writes([A_shared[v0, v1]])
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0 in T.serial(0, 32):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(
0, 32, thread="threadIdx.x"
):
for ax0_ax1_fused_3 in T.vectorized(0, 2):
with T.block("B_shared"):
v0 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 64
+ ax0_ax1_fused_2 * 2
+ ax0_ax1_fused_3
)
// 32,
)
v1 = T.axis.spatial(
512,
i0_0_i1_0_fused * 32
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 64
+ ax0_ax1_fused_2 * 2
+ ax0_ax1_fused_3
)
% 32,
)
T.reads([B[v0, v1]])
T.writes([B_shared[v0, v1]])
B_shared[v0, v1] = B[v0, v1]
for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
with T.block("C"):
i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
j = T.axis.spatial(
512,
i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
)
k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2)
T.reads([A_shared[i, k], B_shared[k, j]])
T.writes([C_local[i, j]])
T.block_attr({"warp_execution": 1})
with T.init():
C_local[i, j] = T.float32(0)
C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
for ax0, ax1 in T.grid(32, 4):
with T.block("C_local"):
v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
v1 = T.axis.spatial(
512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1
)
T.reads([C_local[v0, v1]])
T.writes([C[v0, v1]])
C[v0, v1] = C_local[v0, v1]


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

Expand Down Expand Up @@ -147,5 +250,51 @@ def test_rewrite_cooperative_fetch():
tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)


def test_rewrite_warp_execution():
mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
target = _target()
ctx = _create_context(mod, target)

sch = tir.Schedule(mod, debug_mask="all")
# fmt: off
# pylint: disable=line-too-long,invalid-name
b0 = sch.get_block(name="C", func_name="main")
b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
l2, l3, l4 = sch.get_loops(block=b0)
sch.annotate(b0, "warp_execution", 1)
v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
l31 = sch.fuse(l10, l20)
sch.bind(loop=l31, thread_axis="blockIdx.x")
l32 = sch.fuse(l11, l21)
sch.bind(loop=l32, thread_axis="vthread.x")
l33 = sch.fuse(l12, l22)
sch.bind(loop=l33, thread_axis="threadIdx.y")
b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
_, _, _, _, l39, l40 = sch.get_loops(block=b34)
l41 = sch.fuse(l39, l40)
_, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
_, _, _, _, l49, l50 = sch.get_loops(block=b44)
l51 = sch.fuse(l49, l50)
_, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
# pylint: enable=line-too-long,invalid-name
# fmt: on
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
print(sch.mod["main"].script())
tvm.ir.assert_structural_equal(sch.mod, WarpExecutionAfterRewrite)


if __name__ == "__main__":
test_rewrite_cooperative_fetch()
tvm.testing.main()