diff --git a/src/meta_schedule/schedule_rule/auto_tensorize.cc b/src/meta_schedule/schedule_rule/auto_tensorize.cc new file mode 100644 index 0000000000000..1f7461e69d85a --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_tensorize.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "auto_tensorize.h" + +#include "../../tir/schedule/analysis.h" + +namespace tvm { +namespace meta_schedule { + +using tir::LoopRV; + +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name) { + Optional opt_tensorize_info = GetTensorizeLoopMapping( + sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); + if (!opt_tensorize_info) return NullOpt; + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + // Construct a mapping from tir loops back to LoopRVs + Map loop2rv; + { + Array loop_rvs = sch->GetLoops(block_rv); + for (const LoopRV& loop_rv : loop_rvs) { + loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); + } + } + // Split the loops + arith::Analyzer analyzer; + std::unordered_set inner_loops; + std::vector reorder_suffix; + reorder_suffix.resize(info->loop_map.size()); + for (const auto& kv : info->loop_map) { + // Extract mapping (block_loop => desc_loop) + const tir::StmtSRef& block_loop_sref = kv.first; + const tir::ForNode* block_loop = block_loop_sref->StmtAs(); + const tir::ForNode* desc_loop = kv.second.get(); + ICHECK(block_loop != nullptr && desc_loop != nullptr); + // Extract the loop extent + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + const auto* int_block_extent = block_extent.as(); + const auto* int_desc_extent = desc_extent.as(); + ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + // Check divisibility + int64_t total = int_block_extent->value; + int64_t inner = int_desc_extent->value; + ICHECK_EQ(total % inner, 0); + int64_t outer = int_block_extent->value / int_desc_extent->value; + // Do the split + Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + ICHECK_EQ(split.size(), 2); + inner_loops.insert(sch->GetSRef(split[1]).operator->()); + // The inner split will be reordered to the loop domain that is tensorized + int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)); + reorder_suffix[desc_loop_index] = split[1]; + } + // Reorder the loops + std::vector reorder_list; + bool meet = false; + Array all_loops = sch->GetLoops(block_rv); + for (const LoopRV& loop : all_loops) { + if (inner_loops.count(sch->GetSRef(loop).operator->())) { + meet = true; + } else if (meet) { + reorder_list.push_back(loop); + } + } + reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); + sch->Reorder(reorder_list); + ICHECK(!reorder_suffix.empty()); + return reorder_suffix[0]; +} + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_tensorize.h b/src/meta_schedule/schedule_rule/auto_tensorize.h new file mode 100644 index 0000000000000..6dda4ffc07dfc --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_tensorize.h @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name); + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_