[MetaSchedule][Test] Add unittests for T2D #12249

Merged · 1 commit · Jul 31, 2022
16 changes: 13 additions & 3 deletions python/tvm/meta_schedule/testing/space_generation.py
@@ -40,6 +40,8 @@ def _find_match_sketch_id(
sketches: List[Schedule],
expected_mod: IRModule,
expected_decision: List[Tuple[str, List[int]]],
+    *,
+    debug_mask="all",
) -> Optional[int]:
for sketch_id, sketch in enumerate(sketches):
i = 0
@@ -53,13 +55,13 @@ def _find_match_sketch_id(
i += 1
if len(new_decisions) != len(expected_decision):
continue
sch = Schedule(mod, debug_mask="all")
sch = Schedule(mod, debug_mask=debug_mask)
Trace(
insts=sketch.trace.insts,
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
-            verify_trace_roundtrip(sch=sch, mod=mod)
+            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
return sketch_id
return None

Expand All @@ -69,6 +71,8 @@ def check_sketches(
sketches: List[Schedule],
expected_mods: List[IRModule],
expected_decisions: List[List[Tuple[str, List[int]]]],
+    *,
+    debug_mask="all",
):
assert len(expected_mods) == len(expected_decisions)
assert len(sketches) == len(expected_mods)
Expand All @@ -79,7 +83,13 @@ def check_sketches(
for expected_id, (expected_mod, expected_decision) in enumerate(
zip(expected_mods, expected_decisions)
):
-        sketch_id = _find_match_sketch_id(mod, sketches, expected_mod, expected_decision)
+        sketch_id = _find_match_sketch_id(
+            mod,
+            sketches,
+            expected_mod,
+            expected_decision,
+            debug_mask=debug_mask,
+        )
if sketch_id is None:
raise AssertionError(
f"Expected sketch #{expected_id} doesn't exist in the generated sketches."
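
The change above makes `debug_mask` a keyword-only argument (defaulting to `"all"`, i.e. every schedule verification check enabled) and threads it through both helpers, so tests such as the new T2D ones below can disable verification with `debug_mask=0`. A minimal usage sketch; here `mod`, `sketches`, `expected_mod`, and `decisions` are placeholders, not names from this PR:

    from tvm.meta_schedule.testing.space_generation import check_sketches

    check_sketches(
        mod,
        sketches=sketches,
        expected_mods=[expected_mod],
        expected_decisions=[decisions],
        debug_mask=0,  # 0 skips Schedule verification; "all" (the default) keeps it
    )
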
162 changes: 162 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1375,6 +1375,167 @@
)


def test_cpu_t2d():
# fmt: off
@T.prim_func
def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32")
conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32")
for i0, i1, i2, i3 in T.grid(1, 6, 6, 512):
with T.block("PadInput"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 5 and 1 <= i2_1 and i2_1 < 5, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 2, 8, 1, 4, 1, 4):
for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(8, i1_1_1 * 2 + ax1)
v2 = T.axis.spatial(8, i2_0 * 4 + ax2)
v3 = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + ax3)
T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3]
@T.prim_func
def t2d_1(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32")
conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32")
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 8):
for ax0, ax1, ax2, ax3 in T.grid(1, 6, 4, 512):
with T.block("PadInput"):
i0, i1 = T.axis.remap("SS", [ax0, ax1])
i2 = T.axis.spatial(6, i2_0 * 2 + ax2)
i3 = T.axis.spatial(512, ax3)
T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
T.writes(PadInput[i0, i1, i2, i3])
PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 5 and 1 <= i2 and i2 < 5, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
for i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 32):
with T.block("conv2d_transpose_nhwc_global"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(8, i2_0 * 4 + ax2)
v3 = T.axis.spatial(256, i3_0 * 32 + ax3)
T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3]
@T.prim_func
def t2d_2(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 8, 1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc[n, h, w, co] = conv2d_transpose_nhwc[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, T.if_then_else(1 <= (h + rh) // 2 and (h + rh) // 2 < 5 and 1 <= (w + rw) // 2 and (w + rw) // 2 < 5, inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], T.float32(0), dtype="float32"), T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 2),
("SampleComputeLocation", -1),
]
decision_1 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 2),
("SampleComputeLocation", 3),
]
decision_2 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 3),
("SampleComputeLocation", -2),
]
mod = create_te_workload("T2D", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[t2d_0, t2d_1, t2d_2],
expected_decisions=[decision_0, decision_1, decision_2],
debug_mask=0,
)


if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
@@ -1384,3 +1545,4 @@
test_cpu_dil()
test_cpu_gmm()
test_cpu_grp()
+    test_cpu_t2d()
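
For reference, the decision lists above pin the sampled values in each trace: for example, `("SamplePerfectTile", [2, 1, 1, 4])` in `decision_0` is the split of the `w` axis, which appears in `t2d_0` as loops `i2_0, i2_1_1, i2_2, i2_3` with extents 2, 1, 1, 4. The design space itself can be regenerated interactively. A minimal sketch, assuming an LLVM-enabled TVM build; the explicit target string is an assumption, since the test uses this file's `_target()` helper:

    from tvm import meta_schedule as ms
    from tvm.meta_schedule.testing.te_workload import create_te_workload
    from tvm.target import Target

    mod = create_te_workload("T2D", 0)
    spaces = ms.TuneContext(
        mod=mod,
        target=Target("llvm --num-cores=16"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    for i, space in enumerate(spaces):
        print(f"=== sketch {i} ===")
        print(space.mod.script())  # candidate IRModule, compare against t2d_0/1/2
        print(space.trace)         # the instructions plus their sampled decisions
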
92 changes: 92 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -742,6 +742,97 @@
)


def test_cuda_t2d():
# fmt: off
@T.prim_func
def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":64})
conv2d_transpose_nhwc_local = T.alloc_buffer([1, 8, 8, 256], dtype="float32", scope="local")
PadInput_shared = T.alloc_buffer([1, 6, 6, 512], dtype="float32", scope="shared")
weight_shared = T.alloc_buffer([4, 4, 512, 256], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(256, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"):
for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(1, thread="threadIdx.x"):
for i4_0, i5_0, i6_0 in T.grid(4, 1, 16):
for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96)
v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
T.writes(PadInput_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":2})
PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 5 and 1 <= v2 and v2 < 5, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32")
for ax0_ax1_ax2_ax3_fused in T.serial(2048):
with T.block("weight_shared"):
v0 = T.axis.spatial(4, i4_0 * -1 + 3)
v1 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused // 512)
v2 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 512 // 16)
v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + ax0_ax1_ax2_ax3_fused % 16)
T.reads(weight[v0, v1, v2, v3])
T.writes(weight_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":4})
weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 4, 1, 2, 1, 8, 1, 4, 8, 1, 1, 2, 1):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_4)
h = T.axis.spatial(8, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + i1_3)
w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + i2_3 * 2 + i2_4)
co = T.axis.spatial(256, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 + i4_1 + i4_2)
rw = T.axis.reduce(4, i5_0 * 4 + i5_1 * 4 + i5_2)
rc = T.axis.reduce(512, i6_0 * 32 + i6_1 * 8 + i6_2)
T.reads(PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], weight_shared[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_local[n, h, w, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
conv2d_transpose_nhwc_local[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_local[n, h, w, co] = conv2d_transpose_nhwc_local[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight_shared[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 2, 8):
with T.block("conv2d_transpose_nhwc_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + ax1)
v2 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + ax2)
v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + ax3)
T.reads(conv2d_transpose_nhwc_local[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_local[v0, v1, v2, v3]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [4, 1, 1, 2, 1]),
("SamplePerfectTile", [4, 1, 1, 1, 2]),
("SamplePerfectTile", [16, 2, 1, 8, 1]),
("SamplePerfectTile", [4, 1, 1]),
("SamplePerfectTile", [1, 1, 4]),
("SamplePerfectTile", [16, 4, 8]),
("SampleCategorical", 1),
("SampleCategorical", 3),
("SampleCategorical", 2),
]
mod = create_te_workload("T2D", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[t2d_0],
expected_decisions=[decision_0],
debug_mask=0,
)


if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
@@ -751,3 +842,4 @@
test_cuda_dil()
test_cuda_gmm()
test_cuda_grp()
+    test_cuda_t2d()
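
Assuming a local TVM build with pytest available, the new tests can also be run in isolation, for example:

    pytest tests/python/unittest/test_meta_schedule_space_cpu.py -k t2d
    pytest tests/python/unittest/test_meta_schedule_space_cuda.py -k t2d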