From 6749e099d102615b1a1cc385609e65d4743f4197 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 21 Apr 2022 11:08:10 +0900 Subject: [PATCH] [TIR] Add TileWithTensorIntrin (#11075) Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Wuwei Lin --- python/tvm/tir/schedule/__init__.py | 1 + python/tvm/tir/schedule/transform.py | 42 ++++ src/tir/schedule/transform.cc | 63 ++++++ src/tir/schedule/transform.h | 13 ++ .../unittest/test_tir_schedule_transform.py | 181 ++++++++++++++++++ 5 files changed, 300 insertions(+) create mode 100644 python/tvm/tir/schedule/transform.py create mode 100644 tests/python/unittest/test_tir_schedule_transform.py diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 66ac7b9d772b..63638a89459e 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -24,3 +24,4 @@ from .trace import Trace from . import analysis +from . import transform diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py new file mode 100644 index 000000000000..5dbc06846d52 --- /dev/null +++ b/python/tvm/tir/schedule/transform.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Transformation on TIR schedule.""" from typing import Optional from tvm.tir.schedule import Schedule, BlockRV, LoopRV from . import _ffi_api def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]: """Tile a subset of loops in the block according to the given tensor intrinsic. Parameters ---------- sch : Schedule The schedule to which tiling is applied block : BlockRV The block whose subset of loops will be tiled intrin_name : str The name of a tensor intrinsic, must be registered via TensorIntrin.register(...) 
beforehand + + Returns + ------- + tiled_loop_rv : Optional[LoopRV] + LoopRV corresponding to the outermost loop of a block tiled according to the given intrin + NullOpt if no valid loop mapping is found + """ + return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name) # type: ignore diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d52628..b2e71a9a0d3b 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name) { + Optional opt_tensorize_info = GetTensorizeLoopMapping( + sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); + if (!opt_tensorize_info) return NullOpt; + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + // Construct a mapping from tir loops back to LoopRVs + Map loop2rv; + { + Array loop_rvs = sch->GetLoops(block_rv); + for (const LoopRV& loop_rv : loop_rvs) { + loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); + } + } + // Split the loops + arith::Analyzer analyzer; + std::unordered_set inner_loops; + std::vector reorder_suffix; + reorder_suffix.resize(info->loop_map.size()); + for (const auto& kv : info->loop_map) { + // Extract mapping (block_loop => desc_loop) + const tir::StmtSRef& block_loop_sref = kv.first; + const tir::ForNode* block_loop = block_loop_sref->StmtAs(); + const tir::ForNode* desc_loop = kv.second.get(); + ICHECK(block_loop != nullptr && desc_loop != nullptr); + // Extract the loop extent + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + const auto* int_block_extent = block_extent.as(); + const auto* int_desc_extent = desc_extent.as(); + 
ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + // Check divisibility + int64_t total = int_block_extent->value; + int64_t inner = int_desc_extent->value; + ICHECK_EQ(total % inner, 0); + int64_t outer = int_block_extent->value / int_desc_extent->value; + // Do the split + Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + ICHECK_EQ(split.size(), 2); + inner_loops.insert(sch->GetSRef(split[1]).operator->()); + // The inner split will be reordered to the loop domain that is tensorized + int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)); + reorder_suffix[desc_loop_index] = split[1]; + } + // Reorder the loops + std::vector reorder_list; + bool meet = false; + Array all_loops = sch->GetLoops(block_rv); + for (const LoopRV& loop : all_loops) { + if (inner_loops.count(sch->GetSRef(loop).operator->())) { + meet = true; + } else if (meet) { + reorder_list.push_back(loop); + } + } + reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); + sch->Reorder(reorder_list); + ICHECK(!reorder_suffix.empty()); + return reorder_suffix[0]; +} + +TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 3932c4bdbd3d..12326b3418dd 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ #define TVM_TIR_SCHEDULE_TRANSFORM_H_ +#include #include namespace tvm { @@ -104,6 +105,18 @@ Array ReplaceBuffer(Array match_buffers, c void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt); +/*! + * \brief Tile a subset of loops in the block according to the given tensor intrinsic. 
+ * \param sch The schedule to which tiling is applied + * \param block_rv The block whose subset of loops will be tiled + * \param intrin_name The name of a tensor intrinsic, must be registered via + * TensorIntrin.register(...) beforehand + * \return LoopRV corresponding to the outermost loop of a + * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found + */ +Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name); + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py new file mode 100644 index 000000000000..6dfd4315ec90 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_transform.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN + +from tvm.tir import Schedule +from tvm.script import tir as T +from tvm.tir.schedule.transform import tile_with_tensor_intrin + + +@tvm.script.ir_module +class DenseVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2 in T.grid(1024, 1024, 1024): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +@tvm.script.ir_module +class DenseVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4): + with T.block("compute"): + i = T.axis.spatial(1024, i0) + j = T.axis.spatial(1024, i1_0 * 16 + i1_1) + k = T.axis.reduce(1024, i2_0 * 4 + i2_1) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + 
conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid( + 1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4 + ): + with T.block("conv2d_NCHWc_int8"): + n = T.axis.spatial(1, 0) + oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1]) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + 
T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +def test_tile_with_tensor_intrin_dense_vnni(): + s = Schedule(DenseVNNIModule) + block = s.get_block("compute") + + tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + + _, _, _, i1_1, _ = s.get_loops(block) + + assert s.get(tiled_loop) == s.get(i1_1) + tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled) + + +def test_tile_with_tensor_intrin_conv2d_nchwc_vnni(): + s = Schedule(Conv2dNCHWcVNNIModule) + block = s.get_block("conv2d_NCHWc_int8") + + tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + + tiled_loops = s.get_loops(block) + + assert len(tiled_loops) == 12 + assert s.get(tiled_loop) == s.get(tiled_loops[-2]) + + tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled) + + +if __name__ == "__main__": + test_tile_with_tensor_intrin_dense_vnni() + test_tile_with_tensor_intrin_conv2d_nchwc_vnni()