From 9ec0974d24763ee7e43900a16ec6b44564f80704 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Apr 2022 05:50:28 +0900 Subject: [PATCH] bring back logic in original code to support loop permutation --- src/tir/schedule/analysis/analysis.cc | 59 +++++++++++++++---- .../unittest/test_tir_schedule_analysis.py | 24 ++++---- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index efd3db46ca09..73cd283cb2de 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include "../utils.h" @@ -2106,22 +2107,58 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, int next_block_ind = block_loops.size() - 1; for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { - const tir::ForNode* desc_loop = desc_loops[i_desc]; + // Step 4.2. Find the corresponding loop of the i-th block var of desc + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + const tir::ForNode* desc_loop = nullptr; + IterVarType iter_type_desc; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(r, + [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { + desc_loop = desc_loops[i]; + iter_type_desc = iter_types_desc[i]; + break; + } + } + if (desc_loop == nullptr || desc_loop->extent.as() == nullptr) { + return NullOpt; + } + const IntImmNode* int_desc_extent = desc_loop->extent.as(); - if (!int_desc_extent) continue; + const tir::ForNode* block_loop = nullptr; + + PrimExpr block_bind; for (int i_block = next_block_ind; i_block >= 0; --i_block) { - const tir::ForNode* block_loop = block_loops[i_block]; - const IntImmNode* int_block_extent = block_loop->extent.as(); + if (iter_types_block[i_block] == iter_type_desc) { + next_block_ind = i_block - 1; + block_bind = block->iter_values[i_block]; + break; + } + } - if (!int_block_extent) continue; - if (int_block_extent->value % int_desc_extent->value != 0) continue; - if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue; + for (int i = 0, n = block_loops.size(); i < n; ++i) { + PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (!UsesVar(r, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { + block_loop = block_loops[i]; + const IntImmNode* int_block_extent = block_loop->extent.as(); - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); - next_block_ind = i_block - 1; - break; + if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + auto it = ret->loop_map.find(block_loop_sref); + if (it == ret->loop_map.end()) { + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + } else if ((*it).second.get() != desc_loop) { + return NullOpt; + } + + break; + } } } diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index c9bee54073ca..9cae3e9815b5 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -240,20 +240,24 @@ def matmul_16x16x16xf16f16f16_desc( s = Schedule(matmul) block = s.get_block("C") + i0, i1, i2 = s.get_loops(block) + desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) - info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) - - desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + for do_reorder in [True, False]: + # Mapping should be invariant to the loop permutation + if do_reorder: + s.reorder(i2, i0, i1) - desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) - i0, i1, i2 = s.get_loops(block) + info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) + assert info is not None + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) - for i in range(3): - assert desc_loops[i] in desc_loop_to_sref + for i in range(3): + assert desc_loops[i] in desc_loop_to_sref - assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) - assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) - assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) + assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) if __name__ == "__main__":