Skip to content

Commit

Permalink
TilingwithTensorIntrin works
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 86baa31 commit 2b53437
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) {
*/
class MultiLevelTilingNode : public ScheduleRuleNode {
public:
inline std::vector<State> TileForVNNI(State state) const;
// SubRule 1. add write cache
inline std::vector<State> AddWriteReuse(State state) const;
// SubRule 2. tile the loop nest
Expand All @@ -390,7 +391,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
}
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);

LOG(INFO) << "Doing multi level tiling";
std::vector<State> states{State(sch, block_rv)};
states = SubRule(std::move(states), [&](State state) { return TileForVNNI(state); });
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
Expand Down Expand Up @@ -444,6 +447,19 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
};

inline std::vector<State> MultiLevelTilingNode::TileForVNNI(State state) const {
std::vector<State> result;
BlockRV block_rv = state.block_rv;
const std::string intrin_name = "dot_16x1x16_uint8_int8_int32_cascadelake";
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
ICHECK(tiled_loop_rv.defined());
LOG(INFO) << "After TilingwithTensorIntrin" << state.sch->mod();
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
result.push_back(state);
return result;
}

inline std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
const ReuseConfig& config = this->reuse_write_;
if (config.req == ReuseType::kNoReuse) {
Expand Down Expand Up @@ -503,6 +519,7 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
LOG(INFO) << "Tile loops: " << loops.size();
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
Expand Down

0 comments on commit 2b53437

Please sign in to comment.