forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Metaschedule] Add TilingwithTensorIntrin
- Loading branch information
Showing
2 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include "auto_tensorize.h" | ||
|
||
#include "../../tir/schedule/analysis.h" | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
using tir::LoopRV; | ||
|
||
Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, | ||
const String& intrin_name) { | ||
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping( | ||
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); | ||
if (!opt_tensorize_info) return NullOpt; | ||
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); | ||
// Construct a mapping from tir loops back to LoopRVs | ||
Map<tir::StmtSRef, LoopRV> loop2rv; | ||
{ | ||
Array<LoopRV> loop_rvs = sch->GetLoops(block_rv); | ||
for (const LoopRV& loop_rv : loop_rvs) { | ||
loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); | ||
} | ||
} | ||
// Split the loops | ||
arith::Analyzer analyzer; | ||
std::unordered_set<const tir::StmtSRefNode*> inner_loops; | ||
std::vector<LoopRV> reorder_suffix; | ||
reorder_suffix.resize(info->loop_map.size()); | ||
for (const auto& kv : info->loop_map) { | ||
// Extract mapping (block_loop => desc_loop) | ||
const tir::StmtSRef& block_loop_sref = kv.first; | ||
const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>(); | ||
const tir::ForNode* desc_loop = kv.second.get(); | ||
ICHECK(block_loop != nullptr && desc_loop != nullptr); | ||
// Extract the loop extent | ||
PrimExpr block_extent = analyzer.Simplify(block_loop->extent); | ||
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); | ||
const auto* int_block_extent = block_extent.as<IntImmNode>(); | ||
const auto* int_desc_extent = desc_extent.as<IntImmNode>(); | ||
ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); | ||
// Check divisibility | ||
int64_t total = int_block_extent->value; | ||
int64_t inner = int_desc_extent->value; | ||
ICHECK_EQ(total % inner, 0); | ||
int64_t outer = int_block_extent->value / int_desc_extent->value; | ||
// Do the split | ||
Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); | ||
ICHECK_EQ(split.size(), 2); | ||
inner_loops.insert(sch->GetSRef(split[1]).operator->()); | ||
// The inner split will be reordered to the loop domain that is tensorized | ||
int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop)); | ||
reorder_suffix[desc_loop_index] = split[1]; | ||
} | ||
// Reorder the loops | ||
std::vector<LoopRV> reorder_list; | ||
bool meet = false; | ||
Array<LoopRV> all_loops = sch->GetLoops(block_rv); | ||
for (const LoopRV& loop : all_loops) { | ||
if (inner_loops.count(sch->GetSRef(loop).operator->())) { | ||
meet = true; | ||
} else if (meet) { | ||
reorder_list.push_back(loop); | ||
} | ||
} | ||
reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); | ||
sch->Reorder(reorder_list); | ||
ICHECK(!reorder_suffix.empty()); | ||
return reorder_suffix[0]; | ||
} | ||
|
||
} // namespace meta_schedule | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_ | ||
#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_ | ||
|
||
#include <tvm/tir/schedule/schedule.h> | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, | ||
const String& intrin_name); | ||
|
||
} // namespace meta_schedule | ||
} // namespace tvm | ||
|
||
#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_ |