[MetaSchedule][Test] Add unittests for T2D (#12249)
junrushao authored Jul 31, 2022
1 parent 1d1fc08 commit a842449
Showing 3 changed files with 267 additions and 3 deletions.
16 changes: 13 additions & 3 deletions python/tvm/meta_schedule/testing/space_generation.py
@@ -40,6 +40,8 @@ def _find_match_sketch_id(
sketches: List[Schedule],
expected_mod: IRModule,
expected_decision: List[Tuple[str, List[int]]],
*,
debug_mask="all",
) -> Optional[int]:
for sketch_id, sketch in enumerate(sketches):
i = 0
@@ -53,13 +55,13 @@ def _find_match_sketch_id(
i += 1
if len(new_decisions) != len(expected_decision):
continue
sch = Schedule(mod, debug_mask="all")
sch = Schedule(mod, debug_mask=debug_mask)
Trace(
insts=sketch.trace.insts,
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
-            verify_trace_roundtrip(sch=sch, mod=mod)
+            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
return sketch_id
return None

@@ -69,6 +71,8 @@ def check_sketches(
sketches: List[Schedule],
expected_mods: List[IRModule],
expected_decisions: List[List[Tuple[str, List[int]]]],
*,
debug_mask="all",
):
assert len(expected_mods) == len(expected_decisions)
assert len(sketches) == len(expected_mods)
@@ -79,7 +83,13 @@ def check_sketches(
for expected_id, (expected_mod, expected_decision) in enumerate(
zip(expected_mods, expected_decisions)
):
-        sketch_id = _find_match_sketch_id(mod, sketches, expected_mod, expected_decision)
+        sketch_id = _find_match_sketch_id(
+            mod,
+            sketches,
+            expected_mod,
+            expected_decision,
+            debug_mask=debug_mask,
+        )
if sketch_id is None:
raise AssertionError(
f"Expected sketch #{expected_id} doesn't exist in the generated sketches."
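The keyword-only `debug_mask` argument introduced above flows from `check_sketches` through `_find_match_sketch_id` into both the `Schedule` constructor and `verify_trace_roundtrip`, so callers can relax or disable schedule verification for expensive workloads. A minimal usage sketch follows (illustrative only, not part of this diff; `mod`, `design_space`, `expected_mod_0`, and `expected_decision_0` are hypothetical placeholders for objects a space test would construct):

# Hypothetical sketch: forwarding the new keyword-only debug_mask through check_sketches.
from tvm.meta_schedule.testing.space_generation import check_sketches

check_sketches(
    mod,                                       # IRModule of the workload under test
    sketches=design_space,                     # schedules returned by TuneContext.generate_design_space()
    expected_mods=[expected_mod_0],            # hand-written TVMScript expectation(s)
    expected_decisions=[expected_decision_0],  # sampling decisions matching each expectation
    debug_mask=0,                              # 0 skips Schedule verification, as the T2D tests below do
)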
162 changes: 162 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1375,6 +1375,167 @@ def grp_2(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3,
)


def test_cpu_t2d():
# fmt: off
@T.prim_func
def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32")
conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32")
for i0, i1, i2, i3 in T.grid(1, 6, 6, 512):
with T.block("PadInput"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 5 and 1 <= i2_1 and i2_1 < 5, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 2, 8, 1, 4, 1, 4):
for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(8, i1_1_1 * 2 + ax1)
v2 = T.axis.spatial(8, i2_0 * 4 + ax2)
v3 = T.axis.spatial(256, i3_0 * 32 + i3_1_1 * 8 + ax3)
T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3]
@T.prim_func
def t2d_1(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 6, 6, 512], dtype="float32")
conv2d_transpose_nhwc_global = T.alloc_buffer([1, 8, 8, 256], dtype="float32")
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 8):
for ax0, ax1, ax2, ax3 in T.grid(1, 6, 4, 512):
with T.block("PadInput"):
i0, i1 = T.axis.remap("SS", [ax0, ax1])
i2 = T.axis.spatial(6, i2_0 * 2 + ax2)
i3 = T.axis.spatial(512, ax3)
T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
T.writes(PadInput[i0, i1, i2, i3])
PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 5 and 1 <= i2 and i2 < 5, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
for i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_global[n, h, w, co] = conv2d_transpose_nhwc_global[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 32):
with T.block("conv2d_transpose_nhwc_global"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(8, i2_0 * 4 + ax2)
v3 = T.axis.spatial(256, i3_0 * 32 + ax3)
T.reads(conv2d_transpose_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3]
@T.prim_func
def t2d_2(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 8, 1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(8, i1_0 * 8 + i1_1 * 2 + i1_2 * 2 + i1_3)
w = T.axis.spatial(8, i2_0 * 4 + i2_1 * 4 + i2_2 * 4 + i2_3)
co = T.axis.spatial(256, i3_0 * 32 + i3_1 * 8 + i3_2 * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 * 2 + i4_1)
rw = T.axis.reduce(4, i5_0 * 2 + i5_1)
rc = T.axis.reduce(512, i6_0 * 8 + i6_1)
T.reads(inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], weight[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_transpose_nhwc[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc[n, h, w, co] = conv2d_transpose_nhwc[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, T.if_then_else(1 <= (h + rh) // 2 and (h + rh) // 2 < 5 and 1 <= (w + rw) // 2 and (w + rw) // 2 < 5, inputs[n, (h + rh) // 2 - 1, (w + rw) // 2 - 1, rc], T.float32(0), dtype="float32"), T.float32(0), dtype="float32") * weight[3 - rh, 3 - rw, rc, co]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 2),
("SampleComputeLocation", -1),
]
decision_1 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 2),
("SampleComputeLocation", 3),
]
decision_2 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [1, 4, 1, 2]),
("SamplePerfectTile", [2, 1, 1, 4]),
("SamplePerfectTile", [8, 4, 1, 8]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [2, 2]),
("SamplePerfectTile", [64, 8]),
("SampleCategorical", 3),
("SampleComputeLocation", -2),
]
mod = create_te_workload("T2D", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[t2d_0, t2d_1, t2d_2],
expected_decisions=[decision_0, decision_1, decision_2],
debug_mask=0,
)


if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
@@ -1384,3 +1545,4 @@ def grp_2(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3,
test_cpu_dil()
test_cpu_gmm()
test_cpu_grp()
test_cpu_t2d()
92 changes: 92 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -742,6 +742,97 @@ def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3,
)


def test_cuda_t2d():
# fmt: off
@T.prim_func
def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":64})
conv2d_transpose_nhwc_local = T.alloc_buffer([1, 8, 8, 256], dtype="float32", scope="local")
PadInput_shared = T.alloc_buffer([1, 6, 6, 512], dtype="float32", scope="shared")
weight_shared = T.alloc_buffer([4, 4, 512, 256], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(256, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"):
for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(1, thread="threadIdx.x"):
for i4_0, i5_0, i6_0 in T.grid(4, 1, 16):
for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96)
v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
T.writes(PadInput_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":2})
PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 5 and 1 <= v2 and v2 < 5, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32")
for ax0_ax1_ax2_ax3_fused in T.serial(2048):
with T.block("weight_shared"):
v0 = T.axis.spatial(4, i4_0 * -1 + 3)
v1 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused // 512)
v2 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 512 // 16)
v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + ax0_ax1_ax2_ax3_fused % 16)
T.reads(weight[v0, v1, v2, v3])
T.writes(weight_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":4})
weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 4, 1, 2, 1, 8, 1, 4, 8, 1, 1, 2, 1):
with T.block("conv2d_transpose_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_4)
h = T.axis.spatial(8, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + i1_3)
w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + i2_3 * 2 + i2_4)
co = T.axis.spatial(256, i3_4 + i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + i3_3)
rh = T.axis.reduce(4, i4_0 + i4_1 + i4_2)
rw = T.axis.reduce(4, i5_0 * 4 + i5_1 * 4 + i5_2)
rc = T.axis.reduce(512, i6_0 * 32 + i6_1 * 8 + i6_2)
T.reads(PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], weight_shared[3 - rh, 3 - rw, rc, co])
T.writes(conv2d_transpose_nhwc_local[n, h, w, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
conv2d_transpose_nhwc_local[n, h, w, co] = T.float32(0)
conv2d_transpose_nhwc_local[n, h, w, co] = conv2d_transpose_nhwc_local[n, h, w, co] + T.if_then_else((h + rh) % 2 == 0 and (w + rw) % 2 == 0, PadInput_shared[n, (h + rh) // 2, (w + rw) // 2, rc], T.float32(0), dtype="float32") * weight_shared[3 - rh, 3 - rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 2, 2, 8):
with T.block("conv2d_transpose_nhwc_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused // 64 * 2 + ax1)
v2 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 * 2 + ax2)
v3 = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused % 16 * 16 + i0_1_i1_1_i2_1_i3_1_fused * 8 + ax3)
T.reads(conv2d_transpose_nhwc_local[v0, v1, v2, v3])
T.writes(conv2d_transpose_nhwc[v0, v1, v2, v3])
conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_local[v0, v1, v2, v3]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [4, 1, 1, 2, 1]),
("SamplePerfectTile", [4, 1, 1, 1, 2]),
("SamplePerfectTile", [16, 2, 1, 8, 1]),
("SamplePerfectTile", [4, 1, 1]),
("SamplePerfectTile", [1, 1, 4]),
("SamplePerfectTile", [16, 4, 8]),
("SampleCategorical", 1),
("SampleCategorical", 3),
("SampleCategorical", 2),
]
mod = create_te_workload("T2D", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[t2d_0],
expected_decisions=[decision_0],
debug_mask=0,
)


if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
@@ -751,3 +842,4 @@ def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3,
test_cuda_dil()
test_cuda_gmm()
test_cuda_grp()
test_cuda_t2d()
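Assuming a local TVM build with pytest available, the new tests can be invoked individually; the `if __name__ == "__main__":` blocks also run every space test in each file when executed directly:

pytest tests/python/unittest/test_meta_schedule_space_cpu.py::test_cpu_t2d
pytest tests/python/unittest/test_meta_schedule_space_cuda.py::test_cuda_t2d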
