Skip to content

Commit

Permalink
More hacks to the tensorize loop mapping to make ResNet-50 end-to-end work
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 12, 2022
1 parent 2409674 commit db34397
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
12 changes: 2 additions & 10 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/node.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/schedule/schedule.h>

Expand Down Expand Up @@ -308,19 +307,12 @@ struct ThreadedTraceApply {
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);

trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();

for (int i = 0; i < n_; ++i) {
Item& item = items_[i];
try {
if (!item.postproc->Apply(sch)) {
++item.fail_counter;
return NullOpt;
}
} catch (const std::exception& e) {
LOG(WARNING) << "ThreadedTraceApply::Apply failed with error " << e.what();
if (!item.postproc->Apply(sch)) {
++item.fail_counter;
return NullOpt;
}
}
Expand Down
21 changes: 14 additions & 7 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,10 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
// ^ i_block
// desc_block( v1=..., v2=...)
// ^ i_desc

std::vector<IterVarType> iter_types = GetBlockVarTypes(block_sref);
ICHECK(block_loops.size() == iter_types.size());

for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
// For each block var binding, we find
const PrimExpr& block_bind = block->iter_values[i_block];
Expand All @@ -2073,20 +2077,22 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
LOG(INFO) << "desc bind: " << desc_bind;
// Step 4.1. Find the corresponding loop of the i-th block var of block
const tir::ForNode* block_loop = nullptr;
for (int i = 0, n = block_loops.size(); i < n; ++i) {
for (int i = block_loops.size() - 1; i >= 0; --i) {
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
const auto* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
const auto* int_desc_extent = desc_loops[i_desc]->extent.as<IntImmNode>();

if ((i_desc == 0 && int_block_extent->value == int_desc_extent->value) || !tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
return block_loop_vars.count(var);
})) {
if (i_desc != n_desc_vars - 1 && iter_types[i] == IterVarType::kCommReduce) continue;

if (int_block_extent->value == int_desc_extent->value) {
block_loop = block_loops[i];
LOG(INFO) << "Selected " << i << " th block loop " << block_loops[i]->loop_var << ", " << block_loop->extent;
LOG(INFO) << "Selected " << i << " th block loop " << block_loops[i]->loop_var << ", "
<< block_loop->extent;
break;
} else {
LOG(INFO) << i << " th block loop not ok " << ", " << block_loops[i]->loop_var << ", " << block_loops[i]->extent;
LOG(INFO) << i << " th block loop not ok "
<< ", " << block_loops[i]->loop_var << ", " << block_loops[i]->extent;
}
}
if (block_loop == nullptr) {
Expand All @@ -2101,7 +2107,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
return desc_loop_vars.count(var);
})) {
desc_loop = desc_loops[i];
LOG(INFO) << "Selected " << i << " th desc loop " << desc_loop->extent;;
LOG(INFO) << "Selected " << i << " th desc loop " << desc_loop->extent;
;
break;
}
}
Expand Down

0 comments on commit db34397

Please sign in to comment.