From ed4a8ccdb6e101e92870e69c99871aed39036715 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 9 Nov 2021 08:21:31 -0800 Subject: [PATCH] [Tests] Add unittests for auto-inline and multi-level-tiling (#508) --- .../meta_schedule/testing/space_generation.py | 65 ++++ .../schedule_rule/multi_level_tiling.cc | 9 +- ...meta_schedule_schedule_rule_auto_inline.py | 288 ++++++++++++++++++ ...hedule_schedule_rule_multi_level_tiling.py | 276 +++++++++++++++++ .../unittest/test_meta_schedule_sketch_cpu.py | 79 ++--- .../test_meta_schedule_sketch_cuda.py | 74 ++--- 6 files changed, 676 insertions(+), 115 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/space_generation.py create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py new file mode 100644 index 0000000000..4abf090ddf --- /dev/null +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List, Union + +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.target import Target +from tvm.tir import PrimFunc, Schedule +from tvm.tir.schedule import Trace + +from . 
import schedule_rule as sch_rule + + +def create_context(mod: Union[IRModule, PrimFunc], target: Target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=sch_rule.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + + +def check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) + + +def debug_print_spaces(spaces: List[Schedule], trace_as_list: bool) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + if trace_as_list: + print(str(trace).strip().splitlines()) + else: + print(trace) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index dabfc80526..62884b6bc9 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -241,11 +241,12 @@ inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const results.emplace_back(std::move(new_state)); } // Case 3. Add one write cache - state.write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, - /*storage_scope=*/config.scope); + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); + state.write_cache = write_cache; { - tir::Annotate(state.sch->state(), state.sch->GetSRef(state.write_cache.value()), // - tir::attr::meta_schedule_cache_type, // + tir::Annotate(state.sch->state(), state.sch->GetSRef(write_cache), // + tir::attr::meta_schedule_cache_type, // Integer(tir::attr::meta_schedule_cache_type_write)); } diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index ae60803d08..a05ddaf568 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -15,3 +15,291 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + auto_inline, + auto_inline_after_tiling, +) +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Conv2DBiasBnReLU: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bias_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i, j, k, l], B[j, 0, 0]]) + T.writes([bias_add[i, j, k, l]]) + bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_mul"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bias_add[i, j, k, l], bn_scale[j, 0, 0]]) + T.writes([bn_mul[i, j, k, l]]) + bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_mul[i, j, k, l], bn_offset[j, 0, 0]]) + T.writes([bn_add[i, j, k, l]]) + bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_add[i0_2, i1_2, i2_2, i3_2]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) + + 
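The next fixture, Conv2DBiasBnReLUInlined, is the module AutoInline is expected to produce from Conv2DBiasBnReLU above: the bias_add, bn_mul, and bn_add blocks are folded into the trailing ReLU block, leaving only the padding and convolution stages plus one fused epilogue. A minimal sketch of the manual equivalent, assuming a TVM build with TensorIR scheduling (all module and block names are taken from the fixtures in this file):

    import tvm

    # Inline the three elementwise epilogue blocks into their consumer by
    # hand; the AutoInline rule is expected to reach the same module.
    sch = tvm.tir.Schedule(Conv2DBiasBnReLU)
    for name in ("bias_add", "bn_mul", "bn_add"):
        sch.compute_inline(sch.get_block(name))
    tvm.ir.assert_structural_equal(sch.mod, Conv2DBiasBnReLUInlined)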
+@tvm.script.ir_module +class Conv2DBiasBnReLUInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 0, 0]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class NeedsInlinePaddingAndEpilogue: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0_0_i1_0_i2_0_i3_0_fused in 
T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.lazy_cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([pad_temp[v0, v1, v2, v3]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.lazy_cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.block_attr({"meta_schedule.cache_type":1}) + T.reads([compute_local[v0, v1, v2, v3]]) + T.writes([compute_1[v0, v1, v2, v3]]) + compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 
0, 0]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class PaddingAndEpilogueInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.lazy_cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([X[v0, v1, v2 - 1, v3 - 1]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = T.if_then_else(v2 >= 1 and v2 < 57 and v3 >= 1 and v3 < 57, X[v0, v1, v2 - 1, v3 - 1], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.lazy_cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + 
T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.reads([compute_local[v0, v1, v2, v3], B[v1, 0, 0], bn_scale[v1, 0, 0], bn_offset[v1, 0, 0]]) + T.writes([compute[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":1}) + compute[v0, v1, v2, v3] = T.max((compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0], T.float32(0)) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_inline_consumer_chain(): + mod = Conv2DBiasBnReLU + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) + + +def test_inline_into_cache(): + mod = NeedsInlinePaddingAndEpilogue + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=NeedsInlinePaddingAndEpilogue, + target=target, + rule=auto_inline_after_tiling(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=PaddingAndEpilogueInlined) + + +if __name__ == "__main__": + test_inline_consumer_chain() + test_inline_into_cache() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py new file mode 100644 index 0000000000..03f35749f0 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
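The expected traces in this file encode the multi-level-tiling structure strings: "SSRSRS" on CPU (each spatial loop split four ways, the reduction loop two ways, tiles interleaved) and "SSSRRSRS" on CUDA. A minimal hand-written sketch of the CPU pattern for the 512x512x512 matmul, assuming a TVM build with TensorIR; the concrete tile sizes are illustrative stand-ins for what the rule draws via sample_perfect_tile:

    import tvm
    from tvm.meta_schedule.testing import te_workload
    from tvm.te import create_prim_func

    sch = tvm.tir.Schedule(create_prim_func(te_workload.matmul(n=512, m=512, k=512)))
    i, j, k = sch.get_loops(sch.get_block("C"))
    # 4-way spatial splits and a 2-way reduction split (each factor list
    # multiplies back to 512), then interleave the tiles: S S R S R S.
    i0, i1, i2, i3 = sch.split(i, factors=[8, 8, 2, 4])
    j0, j1, j2, j3 = sch.split(j, factors=[8, 8, 2, 4])
    k0, k1 = sch.split(k, factors=[256, 2])
    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)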
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.te import create_prim_func +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cpu_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = 
sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cuda_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + 
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + "l44, l45 = sch.split(loop=l41, factors=[v42, v43])", + "sch.vectorize(loop=l45)", + 'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)", + "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", + "l53 = sch.fuse(l51, l52)", + "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", + "l56, l57 = sch.split(loop=l53, factors=[v54, v55])", + "sch.vectorize(loop=l57)", + 'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + "l44, l45 = sch.split(loop=l41, factors=[v42, v43])", + "sch.vectorize(loop=l45)", + 'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)", + "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", + "l53 = sch.fuse(l51, l52)", + "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", + "l56, l57 = sch.split(loop=l53, factors=[v54, v55])", + "sch.vectorize(loop=l57)", + 'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( 
+ te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + test_cuda_matmul_relu() diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py index b5dfdadaa1..cdcc518b69 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cpu.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py @@ -18,52 +18,14 @@ from typing import List -from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context from tvm.target import Target from tvm.te import create_prim_func -from tvm.tir.schedule import Trace -from tvm.tir.schedule.schedule import Schedule -def _create_context(mod): - from tvm.meta_schedule.testing import ( # pylint: disable=import-outside-toplevel - schedule_rule as sch_rules, - ) - - target = Target("llvm") - ctx = ms.TuneContext( - mod=mod, - target=target, - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=sch_rules.get(target), - task_name="test", - ) - ctx.space_generator.initialize_with_tune_context(ctx) - for rule in ctx.sch_rules: - rule.initialize_with_tune_context(ctx) - return ctx - - -def _check_trace(spaces: List[Schedule], expected: List[List[str]]): - expected_traces = {"\n".join(t) for t in expected} - actual_traces = set() - for space in spaces: - trace = Trace(space.trace.insts, {}) - trace = trace.simplified(remove_postproc=True) - str_trace = "\n".join(str(trace).strip().splitlines()) - actual_traces.add(str_trace) - assert str_trace in expected_traces, "\n" + str_trace - assert len(expected_traces) == len(actual_traces) - - -def _debug_print(spaces): - for i, space in enumerate(spaces): - print(f"##### Space {i}") - print(space.mod.script()) - trace = Trace(space.trace.insts, {}) - trace = trace.simplified(remove_postproc=True) - print(str(trace).strip().splitlines()) +def _target() -> Target: + return Target("llvm") def test_meta_schedule_cpu_sketch_matmul(): @@ -106,18 +68,19 @@ def test_meta_schedule_cpu_sketch_matmul(): "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", ], ] - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.matmul( n=512, m=512, k=512, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 3 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cpu_sketch_matmul_relu(): @@ -162,18 +125,19 @@ def test_meta_schedule_cpu_sketch_matmul_relu(): ], ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.matmul_relu( n=512, m=512, k=512, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 3 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cpu_sketch_conv2d_nchw(): @@ -242,7 +206,7 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw(): ], ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.conv2d_nchw( n=1, @@ -255,11 +219,12 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw(): stride=1, padding=1, ) - ) + ), + 
target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 3 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name @@ -346,7 +311,7 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable ], ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.conv2d_nchw_bias_bn_relu( n=1, @@ -359,16 +324,17 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable stride=1, padding=1, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 3 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): expected: List[List[str]] = [[]] - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.max_pool2d_nchw( n=1, @@ -377,11 +343,12 @@ def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): ci=512, padding=1, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 - _check_trace(spaces, expected) + check_trace(spaces, expected) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py index 9d18b0f596..0e6be26fa0 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -16,54 +16,14 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from typing import List - -from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context from tvm.target import Target from tvm.te import create_prim_func -from tvm.tir.schedule import Trace -from tvm.tir.schedule.schedule import Schedule - - -def _create_context(mod): - from tvm.meta_schedule.testing import ( # pylint: disable=import-outside-toplevel - schedule_rule as sch_rules, - ) - - target = Target("cuda", host="llvm") - ctx = ms.TuneContext( - mod=mod, - target=target, - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=sch_rules.get(target), - task_name="test", - ) - ctx.space_generator.initialize_with_tune_context(ctx) - for rule in ctx.sch_rules: - rule.initialize_with_tune_context(ctx) - return ctx - - -def _check_trace(spaces: List[Schedule], expected: List[List[str]]): - expected_traces = {"\n".join(t) for t in expected} - actual_traces = set() - for space in spaces: - trace = Trace(space.trace.insts, {}) - trace = trace.simplified(remove_postproc=True) - str_trace = "\n".join(str(trace).strip().splitlines()) - actual_traces.add(str_trace) - assert str_trace in expected_traces, "\n" + str_trace - assert len(expected_traces) == len(actual_traces) -def _debug_print(spaces: List[Schedule]) -> None: - for i, space in enumerate(spaces): - print(f"##### Space {i}") - print(space.mod.script()) - trace = Trace(space.trace.insts, {}) - trace = trace.simplified(remove_postproc=True) - print(str(trace).strip().splitlines()) +def _target() -> Target: + return Target("cuda", host="llvm") def test_meta_schedule_cuda_sketch_matmul(): @@ -106,18 +66,19 @@ def test_meta_schedule_cuda_sketch_matmul(): ] ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = 
create_context( create_prim_func( te_workload.matmul( n=512, m=512, k=512, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cuda_sketch_matmul_relu(): @@ -162,18 +123,19 @@ def test_meta_schedule_cuda_sketch_matmul_relu(): ] ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.matmul_relu( n=512, m=512, k=512, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cuda_sketch_conv2d_nchw(): @@ -226,7 +188,7 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw(): ] ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.conv2d_nchw( n=1, @@ -239,12 +201,13 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw(): stride=1, padding=1, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 - _check_trace(spaces, expected) + check_trace(spaces, expected) def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name @@ -305,7 +268,7 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl ] ] # pylint: enable=line-too-long - ctx = _create_context( + ctx = create_context( create_prim_func( te_workload.conv2d_nchw_bias_bn_relu( n=1, @@ -318,12 +281,13 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl stride=1, padding=1, ) - ) + ), + target=_target(), ) spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 - _check_trace(spaces, expected) + check_trace(spaces, expected) if __name__ == "__main__":
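With the shared helpers in place, every test in this patch follows the same three steps: build a TuneContext with create_context, enumerate the design space, and compare the simplified traces with check_trace. A condensed end-to-end usage sketch, mirroring the max-pool case above (where the design space is a single unscheduled sketch, so the expected trace is empty):

    from tvm.meta_schedule.testing import te_workload
    from tvm.meta_schedule.testing.space_generation import check_trace, create_context
    from tvm.target import Target
    from tvm.te import create_prim_func

    # create_context wires up PostOrderApply with the target's default
    # schedule rules and initializes each rule against the context.
    ctx = create_context(
        create_prim_func(te_workload.max_pool2d_nchw(n=1, h=56, w=56, ci=512, padding=1)),
        target=Target("llvm"),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 1
    check_trace(spaces, expected=[[]])  # no schedule-rule trace for pooling on CPU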