Skip to content

Commit

Permalink
vectorized init loop
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent fcc31ee commit dc9f71d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/meta_schedule/postproc/rewrite_vnni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include "../utils.h"
#include "tvm/runtime/container/base.h"

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -52,6 +53,13 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
tasks.push_back(std::make_tuple(block_name, func_name, intrin_name.value()));
} else {
BlockRV init_block = sch->GetBlock(block_name, func_name);
Array<BlockRV> child_blocks = sch->GetChildBlocks(init_block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
}
}
}
Expand All @@ -61,7 +69,6 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
}

bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
LOG(INFO) << "Apply RewriteVNNI " << sch->mod();
std::vector<BlockPosition> tasks;
for (const auto& kv : sch->mod()->functions) {
GlobalVar g_var = kv.first;
Expand All @@ -76,13 +83,11 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
String intrin_name = std::get<2>(task);
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize);
sch->Tensorize(block_rv, intrin_name);
LOG(INFO) << "After tensorize: " << sch->mod();
}
return true;
}

Postproc RewriteVNNI() {
LOG(INFO) << "RewriteVNNI is called";
ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
return Postproc(n);
}
Expand Down

0 comments on commit dc9f71d

Please sign in to comment.