[MetaSchedule][Testing] Test search space of conv1d (apache#12032)
* [MetaSchedule][Testing] Test search space of conv1d

* Add checks for trace roundtripping
junrushao authored and Mikael Sevenier committed Jul 26, 2022
1 parent 29e0c86 commit b24f762
Showing 2 changed files with 179 additions and 1 deletion.
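For context on the second commit-message bullet: `verify_trace_roundtrip` (newly imported below) asserts that a sketch's trace survives serialization and replay. A minimal sketch of the idea, assuming the `tvm.tir.schedule` API at this commit; `roundtrip_check` is a hypothetical name, and the real helper performs additional validation:

from tvm.ir import IRModule, structural_equal
from tvm.tir import Schedule
from tvm.tir.schedule import Trace


def roundtrip_check(sch: Schedule, mod: IRModule) -> None:
    # Serialize the trace to JSON, replay it on a fresh schedule of the same
    # module, and require the replayed module to be structurally equal.
    json_obj = sch.trace.as_json()
    new_sch = Schedule(mod, debug_mask="all")
    Trace.apply_json_to_schedule(json_obj, new_sch)
    assert structural_equal(new_sch.mod, sch.mod)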
65 changes: 64 additions & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-from typing import List
+from typing import List, Optional, Tuple

from tvm.ir import IRModule, structural_equal
from tvm.tir import Schedule
from tvm.tir.schedule import Trace
from tvm.tir.schedule.testing import verify_trace_roundtrip


def check_trace(spaces: List[Schedule], expected: List[List[str]]):
@@ -31,3 +33,64 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]):
actual_traces.add(str_trace)
assert str_trace in expected_traces, "\n" + str_trace
assert len(expected_traces) == len(actual_traces)


def _find_match_sketch_id(
mod: IRModule,
sketches: List[Schedule],
expected_mod: IRModule,
expected_decision: List[Tuple[str, List[int]]],
) -> Optional[int]:
    """Return the index of the first sketch that reproduces `expected_mod`
    when its trace is replayed with `expected_decision`, or None if no
    sketch matches."""
for sketch_id, sketch in enumerate(sketches):
i = 0
new_decisions = {}
        # Walk the sketch's sampling instructions in order, pairing each with
        # the expected decision of the same instruction kind.
        for inst in sketch.trace.insts:
if not inst.kind.name.startswith("Sample"):
continue
assert i < len(expected_decision)
if inst.kind.name == expected_decision[i][0]:
new_decisions[inst] = expected_decision[i][1]
i += 1
if len(new_decisions) != len(expected_decision):
continue
        # Replay the trace with the expected decisions, postprocessing steps
        # removed, and compare the result structurally with the expectation.
        sch = Schedule(mod, debug_mask="all")
Trace(
insts=sketch.trace.insts,
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
verify_trace_roundtrip(sch=sch, mod=mod)
return sketch_id
return None


def check_sketches(
mod: IRModule,
sketches: List[Schedule],
expected_mods: List[IRModule],
expected_decisions: List[List[Tuple[str, List[int]]]],
):
    """Check that every expected (module, decisions) pair is reproduced by
    exactly one generated sketch, consuming each match so that duplicate
    expectations must find distinct sketches."""
assert len(expected_mods) == len(expected_decisions)
assert len(sketches) == len(expected_mods)
expected_mods = [
IRModule({"main": m}) if not isinstance(m, IRModule) else m for m in expected_mods
]
sketches = list(sketches)
for expected_id, (expected_mod, expected_decision) in enumerate(
zip(expected_mods, expected_decisions)
):
sketch_id = _find_match_sketch_id(mod, sketches, expected_mod, expected_decision)
if sketch_id is None:
raise AssertionError(
f"Expected sketch #{expected_id} doesn't exist in the generated sketches."
)
sketches.pop(sketch_id)


def print_sketches(sketches: List[Schedule]):
    """Print each sketch's TIR script plus its sampling decisions, in the
    tuple format consumed by `check_sketches`."""
for i, sch in enumerate(sketches):
print(f"###### {i}")
print(sch.mod.script())
for inst in sch.trace.insts:
if inst in sch.trace.decisions:
print(f'("{inst.kind.name}", {sch.trace.decisions[inst]}),')
115 changes: 115 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for MetaSchedule search space on CUDA"""
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.meta_schedule.testing.te_workload import create_te_workload
from tvm.script import tir as T
from tvm.target import Target


def _target():
return Target("nvidia/geforce-rtx-3070")


def test_cuda_c1d():
# fmt: off
@T.prim_func
def mod_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":16})
conv1d_nlc_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
PadInput_shared = T.alloc_buffer([1, 258, 64], dtype="float32", scope="shared")
weight_shared = T.alloc_buffer([3, 64, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_fused in T.thread_binding(4, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_fused in T.thread_binding(16, thread="vthread.x"):
for i0_2_i1_2_i2_2_fused in T.thread_binding(4, thread="threadIdx.x"):
for i3_0, i4_0 in T.grid(1, 16):
for ax0_ax1_ax2_fused in T.serial(260):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
T.reads(inputs[v0, v1 - 1, v2])
T.writes(PadInput_shared[v0, v1, v2])
T.block_attr({"meta_schedule.cooperative_fetch":4})
PadInput_shared[v0, v1, v2] = T.if_then_else(1 <= v1 and v1 < 257, inputs[v0, v1 - 1, v2], T.float32(0), dtype="float32")
for ax0_ax1_ax2_fused in T.serial(1536):
with T.block("weight_shared"):
v0 = T.axis.spatial(3, ax0_ax1_ax2_fused // 512)
v1 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 512 // 128)
v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
T.reads(weight[v0, v1, v2])
T.writes(weight_shared[v0, v1, v2])
T.block_attr({"meta_schedule.cooperative_fetch":3})
weight_shared[v0, v1, v2] = weight[v0, v1, v2]
for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
with T.block("conv1d_nlc"):
n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
T.writes(conv1d_nlc_local[n, l, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
conv1d_nlc_local[n, l, co] = T.float32(0)
conv1d_nlc_local[n, l, co] = conv1d_nlc_local[n, l, co] + PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc] * weight_shared[rl, rc, co]
for ax0, ax1, ax2 in T.grid(1, 4, 16):
with T.block("conv1d_nlc_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + ax1)
v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + ax2)
T.reads(conv1d_nlc_local[v0, v1, v2])
T.writes(conv1d_nlc[v0, v1, v2])
conv1d_nlc[v0, v1, v2] = conv1d_nlc_local[v0, v1, v2]
# fmt: on

decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [4, 8, 1, 1, 4]),
("SamplePerfectTile", [1, 2, 4, 2, 8]),
("SamplePerfectTile", [1, 1, 3]),
("SamplePerfectTile", [16, 2, 2]),
("SampleCategorical", 3),
("SampleCategorical", 2),
("SampleCategorical", 1),
]
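    # Reading decision_0: each ("SamplePerfectTile", [...]) entry fixes the
    # split factors of one conv1d loop (n, l, co are spatial; rl, rc are
    # reduction), and each ("SampleCategorical", i) selects one categorical
    # option, e.g. the cooperative-fetch factors and unroll depth seen in mod_0.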

mod = create_te_workload("C1D", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[mod_0],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
test_cuda_c1d()
