Skip to content

Commit

Permalink
[Metaschedule] Add TilingwithTensorIntrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 20, 2022
1 parent 3823b39 commit 70ac9c1
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
91 changes: 91 additions & 0 deletions src/meta_schedule/schedule_rule/auto_tensorize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "auto_tensorize.h"

#include "../../tir/schedule/analysis.h"

namespace tvm {
namespace meta_schedule {

using tir::LoopRV;

/*!
 * \brief Tile the loops around a block so that its innermost loops line up with the loop nest
 *        of a tensor intrinsic's description.
 *
 * Every block loop that GetTensorizeLoopMapping pairs with a description loop is split into
 * (outer, inner), where the inner extent equals the extent of the description loop. The inner
 * loops are then reordered to the innermost positions, placed in the order given by
 * `desc_loop_indexer`.
 *
 * \param sch The schedule on which the splits and the reorder are applied
 * \param block_rv The block whose surrounding loops are tiled
 * \param intrin_name The name used to look up the tensor intrinsic
 * \return The first of the reordered inner loops, or NullOpt when no loop mapping is found
 *         (no split or reorder is applied in that case)
 */
Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                        const String& intrin_name) {
  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
  if (!opt_tensorize_info) return NullOpt;
  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
  // Construct a mapping from tir loops back to LoopRVs
  Map<tir::StmtSRef, LoopRV> loop2rv;
  {
    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
    }
  }
  // Split the loops
  arith::Analyzer analyzer;
  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
  std::vector<LoopRV> reorder_suffix;
  reorder_suffix.resize(info->loop_map.size());
  // The return value is reorder_suffix[0]; verify it exists before the schedule is mutated
  ICHECK(!reorder_suffix.empty());
  for (const auto& kv : info->loop_map) {
    // Extract mapping (block_loop => desc_loop)
    const tir::StmtSRef& block_loop_sref = kv.first;
    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
    const tir::ForNode* desc_loop = kv.second.get();
    ICHECK(block_loop != nullptr && desc_loop != nullptr);
    // Extract the loop extent; both extents must simplify to integer constants
    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
    const auto* int_block_extent = block_extent.as<IntImmNode>();
    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
    // Check divisibility
    int64_t total = int_block_extent->value;
    int64_t inner = int_desc_extent->value;
    // Guard the modulo and the division below: a zero divisor is undefined behaviour
    ICHECK(inner > 0) << "The extent of a description loop must be positive, but got " << inner;
    ICHECK_EQ(total % inner, 0);
    int64_t outer = total / inner;
    // Do the split
    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
    ICHECK_EQ(split.size(), 2);
    inner_loops.insert(sch->GetSRef(split[1]).operator->());
    // The inner split will be reordered to the loop domain that is tensorized
    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
    // std::vector::operator[] is unchecked, so validate the slot before writing to it
    ICHECK(desc_loop_index >= 0 && desc_loop_index < static_cast<int>(reorder_suffix.size()))
        << "Description loop index " << desc_loop_index << " is out of range";
    reorder_suffix[desc_loop_index] = split[1];
  }
  // Reorder the loops: the loops that are not inner splits and sit below the first inner split
  // keep their relative order and are followed by the inner splits
  std::vector<LoopRV> reorder_list;
  bool meet = false;
  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
  for (const LoopRV& loop : all_loops) {
    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
      meet = true;
    } else if (meet) {
      reorder_list.push_back(loop);
    }
  }
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
  sch->Reorder(reorder_list);
  return reorder_suffix[0];
}

} // namespace meta_schedule
} // namespace tvm
33 changes: 33 additions & 0 deletions src/meta_schedule/schedule_rule/auto_tensorize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

/*!
 * \brief Tile the loops around a block according to the loop nest of a tensor intrinsic's
 *        description, and move the resulting inner loops to the innermost positions.
 * \param sch The schedule on which the loop splits and the reorder are applied
 * \param block_rv The block whose surrounding loops are tiled
 * \param intrin_name The name used to look up the tensor intrinsic
 * \return The first of the reordered inner loops, or NullOpt when no loop mapping between the
 *         block and the intrinsic description is found
 */
Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_

0 comments on commit 70ac9c1

Please sign in to comment.