Skip to content

Commit

Permalink
[TIR] Add TileWithTensorIntrin (apache#11075)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people authored and Sergey Shtin committed May 17, 2022
1 parent 821f00e commit 6749e09
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from .trace import Trace

from . import analysis
from . import transform
42 changes: 42 additions & 0 deletions python/tvm/tir/schedule/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Transformation on TIR schedule."""
from typing import Optional

from tvm.tir.schedule import Schedule, BlockRV, LoopRV
from . import _ffi_api


def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]:
    """Tile a subset of loops in the block according to the given tensor intrinsic.

    Parameters
    ----------
    sch : Schedule
        The schedule to which tiling is applied
    block : BlockRV
        The block whose subset of loops will be tiled
    intrin_name : str
        The name of a tensor intrinsic, must be registered via
        TensorIntrin.register(...) beforehand

    Returns
    -------
    tiled_loop_rv : Optional[LoopRV]
        LoopRV corresponding to the outermost loop of a block tiled according to the
        given intrin, or None if no valid loop mapping is found
    """
    return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name)  # type: ignore
63 changes: 63 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}

/*!
 * \brief Split and reorder the loops of `block_rv` so that the innermost loops match the
 * loop nest of the tensor intrinsic's description function, returning the outermost of
 * the matched loops (or NullOpt when no valid mapping exists).
 */
Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                      const String& intrin_name) {
  // Ask the analysis for a mapping from the block's loops onto the loops of the
  // intrinsic's description PrimFunc.
  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
  // No valid mapping: the block cannot be tiled for this intrinsic.
  if (!opt_tensorize_info) return NullOpt;
  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
  // Construct a mapping from tir loops back to LoopRVs
  Map<tir::StmtSRef, LoopRV> loop2rv;
  {
    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
    }
  }
  // Split the loops
  arith::Analyzer analyzer;
  // SRefs of the inner loops produced by the splits below.
  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
  // Inner split loops, ordered by the index of the desc loop each one maps to.
  std::vector<LoopRV> reorder_suffix;
  reorder_suffix.resize(info->loop_map.size());
  for (const auto& kv : info->loop_map) {
    // Extract mapping (block_loop => desc_loop)
    const tir::StmtSRef& block_loop_sref = kv.first;
    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
    const tir::ForNode* desc_loop = kv.second.get();
    ICHECK(block_loop != nullptr && desc_loop != nullptr);
    // Extract the loop extent
    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
    // Both extents must be compile-time constants to compute the split factors.
    const auto* int_block_extent = block_extent.as<IntImmNode>();
    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
    // Check divisibility: the block loop extent must be a multiple of the desc loop extent.
    int64_t total = int_block_extent->value;
    int64_t inner = int_desc_extent->value;
    ICHECK_EQ(total % inner, 0);
    int64_t outer = int_block_extent->value / int_desc_extent->value;
    // Do the split
    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
    ICHECK_EQ(split.size(), 2);
    inner_loops.insert(sch->GetSRef(split[1]).operator->());
    // The inner split will be reordered to the loop domain that is tensorized
    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
    reorder_suffix[desc_loop_index] = split[1];
  }
  // Reorder the loops
  std::vector<LoopRV> reorder_list;
  bool meet = false;
  // Re-query the loops: the splits above changed the loop nest.
  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
  for (const LoopRV& loop : all_loops) {
    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
      meet = true;
    } else if (meet) {
      // Non-inner loops interleaved after the first inner loop are hoisted in front of
      // the tensorized suffix.
      reorder_list.push_back(loop);
    }
  }
  // Append the inner loops in desc-loop order; they become the innermost loop nest.
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
  sch->Reorder(reorder_list);
  ICHECK(!reorder_suffix.empty());
  // The outermost loop of the tensorized region.
  return reorder_suffix[0];
}

TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);

} // namespace tir
} // namespace tvm
13 changes: 13 additions & 0 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_

#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/state.h>

namespace tvm {
Expand Down Expand Up @@ -104,6 +105,18 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
Stmt* src_stmt, Stmt* tgt_stmt);

/*!
 * \brief Tile a subset of loops in the block according to the given tensor intrinsic.
 * \param sch The schedule to which tiling is applied
 * \param block_rv The block whose subset of loops will be tiled
 * \param intrin_name The name of a tensor intrinsic, must be registered via
 * TensorIntrin.register(...) beforehand
 * \return LoopRV corresponding to the outermost loop of a
 * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
 */
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);

} // namespace tir
} // namespace tvm

Expand Down
181 changes: 181 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

from tvm.tir import Schedule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin


# Test fixture: a 1024x1024 uint8 x int8 matmul (accumulating into int32) whose weight
# buffer uses the blocked (64, 256, 16, 4) layout read as [j//16, k//4, j%16, k%4].
# NOTE(review): do not edit the IR itself — it is compared structurally against the
# tiled module in the tests below.
@tvm.script.ir_module
class DenseVNNIModule:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1024, 1024), "uint8"],
        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
        compute: T.Buffer[(1024, 1024), "int32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        with T.block("root"):
            T.reads()
            T.writes()
            for i0, i1, i2 in T.grid(1024, 1024, 1024):
                with T.block("compute"):
                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                    T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
                    T.writes(compute[i, j])
                    with T.init():
                        compute[i, j] = 0
                    compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
                        placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
                    )


# Expected result of tiling DenseVNNIModule with the VNNI 16x4 intrin: the j and k loops
# are split by 16 and 4 respectively and the inner halves (i1_1, i2_1) are innermost.
# NOTE(review): do not edit the IR itself — the tests assert structural equality with it.
@tvm.script.ir_module
class DenseVNNIModuleTiled:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1024, 1024), "uint8"],
        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
        compute: T.Buffer[(1024, 1024), "int32"],
    ) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
            with T.block("compute"):
                i = T.axis.spatial(1024, i0)
                j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
                k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
                T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
                T.writes(compute[i, j])
                with T.init():
                    compute[i, j] = 0
                compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
                    placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
                )


# Test fixture: conv2d in NCHWc layout with 1x1 kernel (kh = kw = 1 in the grid),
# uint8 activations x int8 weights accumulated into int32.
# NOTE(review): do not edit the IR itself — it is compared structurally against the
# tiled module in the tests below.
@tvm.script.ir_module
class Conv2dNCHWcVNNIModule:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
            with T.block("conv2d_NCHWc_int8"):
                (
                    n,
                    oc_chunk,
                    oh,
                    ow,
                    oc_block,
                    kh,
                    kw,
                    ic_outer,
                    ic_f_inner,
                    ic_s_inner,
                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
                T.reads(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                with T.init():
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                    n, oc_chunk, oh, ow, oc_block
                ] + T.cast(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
                ) * T.cast(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                    "int32",
                )


# Expected result of tiling Conv2dNCHWcVNNIModule with the VNNI 16x4 intrin: the
# oc_block (i4) and ic_s_inner (i9) loops are split by 16 and 4 and their inner halves
# (i4_1, i9_1) are innermost.
# NOTE(review): do not edit the IR itself — the tests assert structural equality with it.
@tvm.script.ir_module
class Conv2dNCHWcVNNIModuleTiled:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
    ) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
            1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4
        ):
            with T.block("conv2d_NCHWc_int8"):
                n = T.axis.spatial(1, 0)
                oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
                T.reads(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                with T.init():
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                    n, oc_chunk, oh, ow, oc_block
                ] + T.cast(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
                ) * T.cast(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                    "int32",
                )


def test_tile_with_tensor_intrin_dense_vnni():
    """Tiling the dense block for the VNNI intrin returns the outer of the two inner loops."""
    sch = Schedule(DenseVNNIModule)
    compute = sch.get_block("compute")

    tiled = tile_with_tensor_intrin(sch, compute, VNNI_DOT_16x4_INTRIN)

    # After tiling the nest is (i0, i1_0, i2_0, i1_1, i2_1); the returned loop is i1_1,
    # the fourth loop.
    loops = sch.get_loops(compute)
    assert sch.get(tiled) == sch.get(loops[3])
    tvm.ir.assert_structural_equal(sch.mod, DenseVNNIModuleTiled)


def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
    """Tiling the conv2d NCHWc block yields 12 loops, returning the second-to-last one."""
    sch = Schedule(Conv2dNCHWcVNNIModule)
    conv_block = sch.get_block("conv2d_NCHWc_int8")

    tiled = tile_with_tensor_intrin(sch, conv_block, VNNI_DOT_16x4_INTRIN)

    loop_nest = sch.get_loops(conv_block)
    assert len(loop_nest) == 12
    # The two inner split loops (i4_1, i9_1) are innermost; the returned loop is i4_1.
    assert sch.get(tiled) == sch.get(loop_nest[-2])

    tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTiled)


if __name__ == "__main__":
    # Allow running this test file directly without pytest.
    test_tile_with_tensor_intrin_dense_vnni()
    test_tile_with_tensor_intrin_conv2d_nchwc_vnni()

0 comments on commit 6749e09

Please sign in to comment.