
Commit

fix CI
minminsun committed Aug 26, 2020
1 parent 809998d commit ea60898
Showing 5 changed files with 92 additions and 112 deletions.
10 changes: 5 additions & 5 deletions include/tvm/auto_scheduler/compute_dag.h
@@ -210,8 +210,7 @@ class ComputeDAG : public ObjectRef {
* according to the loop nest derived with `transform_steps`.
* \param transform_steps Transform steps of a state.
*/
void RewriteLayout(
const Array<Step> &transform_steps);
void RewriteLayout(const Array<Step>& transform_steps);
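For orientation, a minimal sketch (not part of this commit) of how this declaration is exercised by ApplySteps in compute_dag.cc further down: the DAG handle is copied and the copy's layout is rewritten in place; `dag` and `transform_steps` are assumed to exist in the caller.

// Assumed context, mirroring the layout_rewrite branch of ApplySteps below:
ComputeDAG new_dag = dag;                // copy-on-write handle copy
new_dag.RewriteLayout(transform_steps);  // mutates only the copy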

/*!
* \brief Apply the history transform steps to get a TVM schedule.
@@ -224,9 +223,10 @@
* \return A `te.schedule` and an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr, bool layout_rewrite = false) const;
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
bool layout_rewrite = false) const;
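Not part of the diff: a self-contained sketch of calling the reformatted ApplySteps with layout rewriting enabled, mirroring the FFI wrapper registered at the end of compute_dag.cc; the helper name ScheduleWithLayoutRewrite is hypothetical.

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <utility>

// Hypothetical helper: apply the recorded transform steps with layout_rewrite=true
// and return the schedule and tensors expected by tvm.lower / tvm.build.
std::pair<tvm::te::Schedule, tvm::Array<tvm::te::Tensor>> ScheduleWithLayoutRewrite(
    const tvm::auto_scheduler::ComputeDAG& dag, const tvm::auto_scheduler::State& state) {
  // stages / stage_to_axes are optional outputs; nullptr is fine when only the
  // schedule and tensors are needed.
  return dag.ApplySteps(state->transform_steps, /*stages=*/nullptr,
                        /*stage_to_axes=*/nullptr, /*layout_rewrite=*/true);
}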

/*!
* \brief Print transform steps as equivalent python schedule API.
1 change: 1 addition & 0 deletions include/tvm/auto_scheduler/transform_step.h
@@ -49,6 +49,7 @@
#include <dmlc/json.h>
#include <tvm/node/node.h>
#include <tvm/te/schedule.h>

#include <vector>

namespace tvm {
184 changes: 83 additions & 101 deletions src/auto_scheduler/compute_dag.cc
@@ -40,8 +40,8 @@
#include <vector>

#include "../arith/pattern_match.h"
#include "utils.h"
#include "search_policy/utils.h"
#include "utils.h"

namespace tvm {
namespace auto_scheduler {
@@ -669,8 +669,7 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
/*!
* \brief utility function for kernel_layout_transform
*/
inline void parse_kernel_layout(const String& layout,
Array<PrimExpr>* shape,
inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
std::vector<std::string>* axes) {
int32_t factor = 0;
std::string axis = "";
@@ -696,20 +695,14 @@ inline void parse_kernel_layout(const String& layout,
}
}

std::string BaseName(const std::string& str) {
return str.substr(0, str.rfind("_"));
}
std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
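A quick illustration of the reformatted one-liner (inputs hypothetical, not from the diff): BaseName keeps everything before the last underscore, which strips the numeric suffixes carried by split iterator names.

BaseName("ff_0");   // -> "ff"
BaseName("i_1_2");  // -> "i_1"
BaseName("rc");     // -> "rc" (no '_': rfind returns npos and the whole string is kept)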

class IndexRewriter : public StmtExprMutator {
public:
IndexRewriter(const te::Operation& placeholder_op,
const std::string& new_layout):
placeholder_op_(placeholder_op),
new_layout_(new_layout) {}
IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
: placeholder_op_(placeholder_op), new_layout_(new_layout) {}

PrimExpr Rewrite(PrimExpr expr) {
return this->VisitExpr(expr);
}
PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }

PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
te::Tensor t = Downcast<te::Tensor>(op->producer);
@@ -721,7 +714,7 @@ class IndexRewriter : public StmtExprMutator {
for (const auto& arg : op->indices) {
std::string axis_name;
if (const auto* pimm = arg.as<IntImmNode>()) {
CHECK_EQ(pimm->value, 0);
CHECK_EQ(pimm->value, 0);
axis_name = "IntImm";
} else {
axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
@@ -763,8 +756,7 @@ class IndexRewriter : public StmtExprMutator {
const std::string& new_layout_;
};
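A minimal usage sketch for the class above (surrounding variables assumed, not part of this commit): the mutator is built from the placeholder whose layout changes plus the new layout string, then run over one expression of a compute op body.

// Assumed context: placeholder_op and new_layout are produced by RewriteLayout below.
IndexRewriter index_rewriter(placeholder_op, new_layout);
PrimExpr new_body = index_rewriter.Rewrite(old_body);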

std::string get_ori_layout(std::set<std::string>* placeholder_axis_names,
const te::Operation& op,
std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
const te::Tensor& placeholder) {
ReadAccessExtractor extractor;
for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
@@ -798,79 +790,74 @@ std::string get_ori_layout(std::set<std::string>* placeholder_axis_names,
return ori_layout;
}

std::string get_new_layout(Array<PrimExpr>* new_shape,
const State& state,
const int stage_id,
const Stage& stage,
const te::Operation& op,
std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
const Stage& stage, const te::Operation& op,
const te::Tensor& placeholder,
const std::set<std::string>& placeholder_axis_names) {
std::ostringstream os;
Array<Iterator> stage_iters;

auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
int attach_pos = -1;
size_t iters_before_attach = 0;
if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
auto attach = attach_it->second;
const auto& attach_stage = state->stages[attach.first];
attach_pos = attach.second;
stage_iters.insert(stage_iters.end(),
attach_stage->iters.begin(),
attach_stage->iters.begin() + attach_pos + 1);
}
std::ostringstream os;
Array<Iterator> stage_iters;

auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
int attach_pos = -1;
size_t iters_before_attach = 0;
if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
auto attach = attach_it->second;
const auto& attach_stage = state->stages[attach.first];
attach_pos = attach.second;
stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
attach_stage->iters.begin() + attach_pos + 1);
}

stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());

std::vector<Iterator> iters;
for (size_t i = 0; i < stage_iters.size(); ++i) {
const auto& iter = stage_iters[i];
if (iter->ori_iters.empty()) {
iters.push_back(iter);
} else {
for (const Iterator& ori_iter : iter->ori_iters) {
iters.push_back(ori_iter);
}
}
if (static_cast<int>(i) == attach_pos) {
iters_before_attach = iters.size();
std::vector<Iterator> iters;
for (size_t i = 0; i < stage_iters.size(); ++i) {
const auto& iter = stage_iters[i];
if (iter->ori_iters.empty()) {
iters.push_back(iter);
} else {
for (const Iterator& ori_iter : iter->ori_iters) {
iters.push_back(ori_iter);
}
}
if (static_cast<int>(i) == attach_pos) {
iters_before_attach = iters.size();
}
}

std::vector<std::string> new_names;
std::vector<std::string> new_axis_names;
for (const Iterator& iter : iters) {
std::set<std::string> ori_iter_names;
ExtractOriginalIterators(iter->name, &ori_iter_names);
// fused iters have been replaced with iter->ori_iters.
// So there should be only one ori iter name extracted from iter->name.
CHECK_EQ(ori_iter_names.size(), 1);
auto ori_iter_name = BaseName(*ori_iter_names.begin());
new_axis_names.push_back(ori_iter_name);
std::vector<std::string> new_names;
std::vector<std::string> new_axis_names;
for (const Iterator& iter : iters) {
std::set<std::string> ori_iter_names;
ExtractOriginalIterators(iter->name, &ori_iter_names);
// fused iters have been replaced with iter->ori_iters.
// So there should be only one ori iter name extracted from iter->name.
CHECK_EQ(ori_iter_names.size(), 1);
auto ori_iter_name = BaseName(*ori_iter_names.begin());
new_axis_names.push_back(ori_iter_name);
}
for (size_t i = 0; i < new_axis_names.size(); ++i) {
auto iter = iters[i];
std::string ori_iter_name;
if (i < iters_before_attach) {
ori_iter_name = new_axis_names[i + iters_before_attach];
} else {
ori_iter_name = new_axis_names[i];
}
for (size_t i = 0; i < new_axis_names.size(); ++i) {
auto iter = iters[i];
std::string ori_iter_name;
if (i < iters_before_attach) {
ori_iter_name = new_axis_names[i + iters_before_attach];
} else {
ori_iter_name = new_axis_names[i];
}
if (placeholder_axis_names.count(ori_iter_name)) {
os << iter->range->extent << ori_iter_name;
new_names.push_back(ori_iter_name);
new_shape->push_back(iter->range->extent);
}
if (placeholder_axis_names.count(ori_iter_name)) {
os << iter->range->extent << ori_iter_name;
new_names.push_back(ori_iter_name);
new_shape->push_back(iter->range->extent);
}
std::string new_layout = os.str();
os.str("");
// TODO(minmin): uncomment this line for relay integration
// ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
return new_layout;
}
std::string new_layout = os.str();
os.str("");
// TODO(minmin): uncomment this line for relay integration
// ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
return new_layout;
}
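To make the returned string concrete, a hypothetical example (axis names and extents assumed, not taken from the diff): each surviving placeholder axis contributes its extent immediately followed by its base name.

// Hypothetical: surviving axes in loop order are oc (extent 32), ic (extent 4),
// kh (extent 3). The `os << iter->range->extent << ori_iter_name` loop then gives
//   new_layout == "32oc4ic3kh"
// and new_shape is filled with {32, 4, 3}.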

void ComputeDAG::RewriteLayout(
const Array<Step> &transform_steps) {
void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
ComputeDAGNode* pdag = this->CopyOnWrite();
auto node = make_object<StateNode>();
node->transform_steps = transform_steps;
@@ -911,8 +898,8 @@ void ComputeDAG::RewriteLayout(
get_ori_layout(&placeholder_axis_names, op, placeholder);

Array<PrimExpr> new_shape;
std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage,
op, placeholder, placeholder_axis_names);
std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
placeholder, placeholder_axis_names);

handled_ops.insert(placeholder_op);

@@ -921,16 +908,13 @@ void ComputeDAG::RewriteLayout(

// Create new placeholder
te::Operation new_placeholder_op;
new_placeholder_op =
te::PlaceholderOp(placeholder_op->name,
new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);
new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);

te::Operation new_compute_op, old_compute_op;
Array<PrimExpr> new_body;
IndexRewriter index_rewriter(placeholder_op,
new_layout);
for (auto& op : old_ops) {
Array<PrimExpr> new_body;
IndexRewriter index_rewriter(placeholder_op, new_layout);
for (auto& op : old_ops) {
if (auto* pop = op.as<te::ComputeOpNode>()) {
bool need_update = false;
for (auto& t : op->InputTensors()) {
@@ -945,8 +929,8 @@ void ComputeDAG::RewriteLayout(
}
old_compute_op = op;
CHECK(!new_compute_op.defined());
new_compute_op = te::ComputeOp(
pop->name, pop->tag, pop->attrs, pop->axis, new_body);
new_compute_op =
te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
}
}
}
@@ -999,7 +983,7 @@ void ComputeDAG::RewriteLayout(

for (size_t i = 0; i < old_tensors.size(); ++i) {
const auto& old_tensor = old_tensors[i];
auto it = updated_ops.find(old_tensor->op);
auto it = updated_ops.find(old_tensor->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
@@ -1013,12 +997,12 @@ void ComputeDAG::RewriteLayout(
} // end for placeholder
}
} // end for compute op
} // end for stage
} // end for stage
}

std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes, bool layout_rewrite) const {
const Array<Step>& transform_steps, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
bool layout_rewrite) const {
if (layout_rewrite && !transform_steps.empty()) {
ComputeDAG new_dag = *this;
new_dag.RewriteLayout(transform_steps);
@@ -1153,9 +1137,8 @@ State ComputeDAG::InferBound(const State& state) const {

auto find_res = bounds.find(axis);
if (find_res != bounds.end()) {
new_iters.push_back(
Iterator(iter->name, (*find_res).second, iter->iter_kind, iter->annotation,
&iter->ori_iters));
new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_kind,
iter->annotation, &iter->ori_iters));
} else {
LOG(FATAL) << "Infer bound fails";
}
@@ -1307,12 +1290,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG").set_body_typed([](Array<te::Ten
});

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
.set_body_typed([](const ComputeDAG& dag, const State& state,
const bool layout_rewrite) {
.set_body_typed([](const ComputeDAG& dag, const State& state, const bool layout_rewrite) {
te::Schedule sch;
Array<te::Tensor> return_tensors;
std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps,
nullptr, nullptr, layout_rewrite);
std::tie(sch, return_tensors) =
dag.ApplySteps(state->transform_steps, nullptr, nullptr, layout_rewrite);
return Array<ObjectRef>{sch, return_tensors};
});
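A hypothetical C++ caller (not part of this commit) reaching the same registered entry point through the global registry, which is conceptually what the Python bindings do over the FFI; `dag` and `state` are assumed to exist.

// Requires <tvm/runtime/registry.h>; assumed context only.
const tvm::runtime::PackedFunc* f =
    tvm::runtime::Registry::Get("auto_scheduler.ComputeDAGApplyStepsFromState");
CHECK(f != nullptr) << "global function not registered";
tvm::Array<tvm::ObjectRef> res = (*f)(dag, state, /*layout_rewrite=*/true);
// res[0] holds the te::Schedule, res[1] the Array<te::Tensor>.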

3 changes: 1 addition & 2 deletions src/auto_scheduler/loop_state.cc
@@ -42,8 +42,7 @@ TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);

/********** Iterator **********/
Iterator::Iterator(String name, Range range, IteratorKind iter_kind,
IteratorAnnotation annotation,
Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
const std::vector<Iterator>* ori_iters) {
auto node = make_object<IteratorNode>();
node->name = std::move(name);
6 changes: 2 additions & 4 deletions src/auto_scheduler/transform_step.cc
@@ -1240,8 +1240,7 @@ void ComputeAtStepNode::ApplyToState(State* state) const {
// compute at
Array<Iterator> new_iters;
for (const Iterator& it : stage->iters) {
new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation,
&it->ori_iters));
new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->ori_iters));
}

StateNode* pstate = state->CopyOnWrite();
@@ -1357,8 +1356,7 @@ void ComputeRootStepNode::ApplyToState(State* state) const {
// compute root
Array<Iterator> new_iters;
for (const Iterator& it : stage->iters) {
new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation,
&it->ori_iters));
new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->ori_iters));
}

StateNode* pstate = state->CopyOnWrite();
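Both hunks above rebuild loop iterators through the Iterator constructor that loop_state.cc reformats; a minimal sketch of that rebuild (assumed context, not part of this commit):

// Assumed context: `it` is an Iterator taken from stage->iters. The Range is
// cleared because bounds must be re-inferred after compute_at / compute_root,
// while ori_iters is preserved so fused loops can later be expanded back into
// their original axes (see get_new_layout in compute_dag.cc).
Iterator rebuilt(it->name, Range(), it->iter_kind, it->annotation, &it->ori_iters);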
