Skip to content

Commit

Permalink
bring back logic in original code to support loop permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 19, 2022
1 parent ec39b62 commit 9ec0974
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
59 changes: 48 additions & 11 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>

#include "../utils.h"
Expand Down Expand Up @@ -2106,22 +2107,58 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

int next_block_ind = block_loops.size() - 1;
for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
const tir::ForNode* desc_loop = desc_loops[i_desc];
// Step 4.2. Find the corresponding loop of the i-th block var of desc
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
const tir::ForNode* desc_loop = nullptr;
IterVarType iter_type_desc;
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
if (!UsesVar(r,
[&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
desc_loop = desc_loops[i];
iter_type_desc = iter_types_desc[i];
break;
}
}
if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
return NullOpt;
}

const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
if (!int_desc_extent) continue;

const tir::ForNode* block_loop = nullptr;

PrimExpr block_bind;
for (int i_block = next_block_ind; i_block >= 0; --i_block) {
const tir::ForNode* block_loop = block_loops[i_block];
const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
if (iter_types_block[i_block] == iter_type_desc) {
next_block_ind = i_block - 1;
block_bind = block->iter_values[i_block];
break;
}
}

if (!int_block_extent) continue;
if (int_block_extent->value % int_desc_extent->value != 0) continue;
if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
for (int i = 0, n = block_loops.size(); i < n; ++i) {
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (!UsesVar(r,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
block_loop = block_loops[i];
const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();

const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
next_block_ind = i_block - 1;
break;
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
return NullOpt;
}

const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
auto it = ret->loop_map.find(block_loop_sref);
if (it == ret->loop_map.end()) {
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
} else if ((*it).second.get() != desc_loop) {
return NullOpt;
}

break;
}
}
}

Expand Down
24 changes: 14 additions & 10 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,20 +240,24 @@ def matmul_16x16x16xf16f16f16_desc(

s = Schedule(matmul)
block = s.get_block("C")
i0, i1, i2 = s.get_loops(block)
desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)

desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
for do_reorder in [True, False]:
# Mapping should be invariant to the loop permutation
if do_reorder:
s.reorder(i2, i0, i1)

desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
i0, i1, i2 = s.get_loops(block)
info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)
assert info is not None
desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

for i in range(3):
assert desc_loops[i] in desc_loop_to_sref
for i in range(3):
assert desc_loops[i] in desc_loop_to_sref

assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)


if __name__ == "__main__":
Expand Down

0 comments on commit 9ec0974

Please sign in to comment.