[Metaschedule] MultiLevelTiling for wide vector architectures #12845

Merged (4 commits, Sep 21, 2022)
Changes from 3 commits
include/tvm/meta_schedule/schedule_rule.h (15 additions, 0 deletions)
@@ -187,6 +187,21 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

/*!
* \brief Extension of MultiLevelTiling for backends with wide vectors.
* The loop over the innermost spatial axis of the output buffer is always vectorized with the
* maximum vector length.
* \param structure The tiling structure. 'SSRSRS' is recommended.
* \param vector_length_in_bits The length of a vector register in bits.
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit.
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
python/tvm/meta_schedule/schedule_rule/__init__.py (1 addition, 0 deletions)
@@ -28,6 +28,7 @@
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
MultiLevelTilingWideVector,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py (37 additions, 0 deletions)
@@ -187,3 +187,40 @@ def __init__(
reuse_write.as_dict() if reuse_write is not None else None,
use_software_pipeline,
)


@register_object("meta_schedule.MultiLevelTilingWideVector")
class MultiLevelTilingWideVector(ScheduleRule):
"""Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
spatial axis of the output buffer is always vectorized with the maximum vector length.

Parameters
----------
structure : str
The tiling structure. 'SSRSRS' is recommended.
vector_length_in_bits : int
The length of a vector register in bits.
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit.
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
structure: str,
vector_length_in_bits: int,
max_innermost_factor: Optional[int] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingWideVector, # type: ignore # pylint: disable=no-member
structure,
vector_length_in_bits,
max_innermost_factor,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
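A minimal usage sketch of the new Python binding. The argument values below are illustrative and mirror the Hexagon test added at the end of this PR; the constructed rule would then be passed to a TuneContext via sch_rules, exactly as that test does.

# Sketch only: construct the wide-vector tiling rule with the parameters
# documented in the docstring above (1024-bit vectors, "SRSRS" structure).
rule = MultiLevelTilingWideVector(
    structure="SRSRS",
    vector_length_in_bits=1024,   # e.g. 1024-bit Hexagon HVX vectors
    max_innermost_factor=64,
    reuse_read=None,
    reuse_write=None,
)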
src/meta_schedule/schedule_rule/multi_level_tiling.cc (24 additions, 11 deletions)
@@ -166,6 +166,17 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
return results;
}

Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(Schedule& sch, BlockRV block, LoopRV loop,
int n_tiles) const {
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
return splits;
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
Schedule& sch = state->sch;
const BlockRV& block_rv = state->block_rv;
@@ -179,6 +190,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;

if (iter_types[i] == IterVarType::kDataPar) {
idx = &s_indices_;
if (spatial_loop_product != -1) {
Expand All @@ -193,17 +205,18 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
} else {
continue;
}
// Do the split
int n_tiles = idx->size();
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);

const int n_tiles = idx->size();

if (n_tiles == 1) {
tiles[idx->at(0)].push_back(loop);
} else {
auto splits = SplitLoop(sch, block_rv, loop, n_tiles);

// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
}
}
}
// Step 3. Reorder to organize the tiles
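The refactor above moves the sample-then-split step out of TileLoopNest into the overridable SplitLoop helper. A minimal Python sketch of what one default SplitLoop call amounts to at the TIR schedule level (the 4-level split and the loop/variable names are illustrative only):

# Sketch of the default SplitLoop behaviour on a single loop:
# sample perfect tile sizes, then split the loop by those factors.
factors = sch.sample_perfect_tile(loop, n=4, max_innermost_factor=64)
l0, l1, l2, l3 = sch.split(loop, factors=factors)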
src/meta_schedule/schedule_rule/multi_level_tiling.h (3 additions, 0 deletions)
@@ -161,6 +161,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

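// Split `loop` of `block` into `n_tiles` tiles. The default implementation samples
// perfect tile sizes and splits accordingly; subclasses (e.g. the wide-vector rule
// added in this PR) can override it to control how particular loops are split.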
virtual Array<tir::LoopRV> SplitLoop(tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop,
int n_tiles) const;

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc (new file, 120 additions)
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "../../tir/schedule/analysis.h"
#include "../../tir/schedule/transform.h"
#include "../utils.h"
#include "multi_level_tiling.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;
using tir::Schedule;

/*!
* \brief Extension of MultiLevelTiling for backends with wide vectors.
* The loop over the innermost spatial axis of the output buffer is always vectorized with the
* maximum vector length.
*/
class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
public:
size_t vector_length_in_bits;

static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode);

protected:
Array<tir::LoopRV> SplitLoop(Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const override;
};

Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(Schedule& sch, BlockRV block_rv,
LoopRV loop_rv, int n_tiles) const {
const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref);
ICHECK(block_node && block_node->writes.size() == 1);

const auto out_dtype = block_node->writes[0]->buffer->dtype;
const int vec_len = vector_length_in_bits / out_dtype.bits();
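// For example, with vector_length_in_bits = 1024 and a float16 (16-bit) output buffer,
// vec_len = 1024 / 16 = 64 lanes.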

// Determine if this loop is over the innermost axis of the output buffer.
// In the example below, we look for a loop whose loop var is bound to the axis co.

// for (i0, 0, 1) {
// for (i1, 0, 56) {
// for (i2, 0, 56) {
// for (i3, 0, 64) {
// for (i4, 0, 3) {
// for (i5, 0, 3) {
// for (i6, 0, 64) {
// block conv2d_nhwc(...) {
// ...
// bind(co, i3)
// ...
// writes([conv2d_nhwc[n, h, w, co]])
// ...
// conv2d_nhwc[n, h, w, co] = ...
// }
const size_t innermost_axis = block_node->writes[0]->region.size() - 1;
const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis];

if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) {
// If this is not the innermost spatial loop, split the loop in the normal way.
return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles);
} else {
// We split the innermost spatial loop in a way that always uses the maximum vector length.
const int64_t* extent_int = tir::GetLoopIntExtent(loop);
if (extent_int && *extent_int > vec_len) {
Array<tir::LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv,
/*factors=*/{NullOpt, PrimExpr(vec_len)});
Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile(
/*loop=*/inner_splits[0],
/*n=*/n_tiles - 1,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> outer_splits = sch->Split(
/*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
outer_splits.push_back(inner_splits[1]);
return outer_splits;
} else {
Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
factors.push_back(loop->extent);
return sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
}
}
}

ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
node->vector_length_in_bits = vector_length_in_bits->value;
return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector")
.set_body_typed(ScheduleRule::MultiLevelTilingWideVector);

} // namespace meta_schedule
} // namespace tvm
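In schedule terms, the innermost spatial loop is handled in two steps: a fixed vector-length inner loop is peeled off first, and the remaining n_tiles - 1 tiles are sampled on the outer part. A rough Python equivalent, assuming vec_len = 64 and three spatial tile levels (names are illustrative only):

# Rough sketch of the wide-vector split of the innermost spatial loop.
outer, vec = sch.split(loop, factors=[None, 64])   # inner loop fixed to 64 lanes
factors = sch.sample_perfect_tile(outer, n=2, max_innermost_factor=64)
l0, l1 = sch.split(outer, factors=factors)         # the remaining two tile levels
# The resulting tiles are [l0, l1, vec]; per the rule's documentation, the vec loop
# is vectorized at the full vector width.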
tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py (107 additions, 1 deletion)
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from tvm import meta_schedule as ms
from tvm import te
from tvm import te, target
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.schedule_rule import get_rules
from tvm.meta_schedule.testing.space_generation import check_sketches
@@ -521,9 +521,115 @@ def sum_with_trivial_block_iter(
assert not sch.trace.simplified(remove_postproc=True).insts


def test_multi_level_tiling_hexagon():
@T.prim_func
def cpu_conv2d_nhwc(
inputs: T.Buffer[(1, 56, 56, 64), "float16"],
weight: T.Buffer[(3, 3, 64, 64), "float16"],
conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"],
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16")
for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
with T.block("PadInput"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57,
inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1],
T.float16(0),
dtype="float16",
)
for (
i0_0,
i1_0,
i2_0,
i3_0,
i4_0,
i5_0,
i6_0,
i0_1_1,
i1_1_1,
i2_1_1,
i3_1_1,
i4_1,
i5_1,
i6_1,
i0_2,
i1_2,
i2_2,
i3_2,
) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0)
h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2)
w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2)
co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2)
rh = T.axis.reduce(3, i4_1 + i4_0)
rw = T.axis.reduce(3, i5_0 + i5_1)
rc = T.axis.reduce(64, i6_0 * 4 + i6_1)
T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co])
T.writes(conv2d_nhwc[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure": "SRSRS"})
with T.init():
conv2d_nhwc[n, h, w, co] = T.float16(0)
conv2d_nhwc[n, h, w, co] = (
conv2d_nhwc[n, h, w, co]
+ PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co]
)

target_hexagon = target.hexagon("v69", num_cores=4)

I = 64
O = 64
H = 56
W = 56

mod = te.create_prim_func(
te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16")
)

actual = ms.TuneContext(
mod=mod,
target=Target(target_hexagon, host=target_hexagon),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules=[
ms.schedule_rule.MultiLevelTilingWideVector(
structure="SRSRS",
vector_length_in_bits=1024,
max_innermost_factor=64,
reuse_read=None,
reuse_write=None,
)
],
task_name="test",
).generate_design_space()

decision_0 = [
("SamplePerfectTile", [1, 1, 1]),
("SamplePerfectTile", [1, 14, 4]),
("SamplePerfectTile", [2, 2, 14]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [16, 4]),
]
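# Only six SamplePerfectTile decisions are expected: the co axis (extent 64) is not
# sampled because the wide-vector rule splits it deterministically. With 1024-bit
# vectors and a float16 output, vec_len = 1024 / 16 = 64, so co becomes [1, 1, 64]
# and the whole axis is consumed by the fixed inner vector loop.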

check_sketches(
mod,
sketches=actual,
expected_mods=[cpu_conv2d_nhwc],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
test_cpu_matmul()
test_cpu_matmul_relu()
test_cuda_matmul()
test_cuda_matmul_relu()
test_cuda_sum_with_trivial_block_iter()
test_multi_level_tiling_hexagon()