
Commit

resolve review comments
minminsun committed Aug 26, 2020
1 parent ea60898 commit fc7be49
Showing 3 changed files with 158 additions and 153 deletions.
3 changes: 3 additions & 0 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -81,6 +81,9 @@ def apply_steps_from_state(self, state, layout_rewrite=False):
state : Union[State, StateObject]
The state from which we get transform steps.
layout_rewrite : bool
Rewrite the layouts of placeholders marked as layout-free.
Returns
-------
A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
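
A minimal usage sketch of the new `layout_rewrite` flag (not part of this commit; `dag` and `state` are illustrative names for a previously built `auto_scheduler.ComputeDAG` and a `State` produced by the search, both assumed to exist):

import tvm

# Assumed to exist already (illustrative names, not defined in this commit):
#   dag   - an auto_scheduler.ComputeDAG built from the workload's output tensors
#   state - a State produced by the auto-scheduler search
sch, args = dag.apply_steps_from_state(state, layout_rewrite=True)
print(tvm.lower(sch, args, simple_mode=True))
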
276 changes: 123 additions & 153 deletions src/auto_scheduler/compute_dag.cc
@@ -666,78 +666,46 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}

/*!
* \brief utility function for kernel_layout_transform
*/
inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
std::vector<std::string>* axes) {
int32_t factor = 0;
std::string axis = "";
for (char c : std::string(layout)) {
if (c >= 'A' && c <= 'z') {
axis += c;
if (factor != 0) {
shape->push_back(factor);
factor = 0;
}
} else if (c >= '0' && c <= '9') {
factor = factor * 10 + c - '0';
if (!axis.empty()) {
axes->push_back(axis);
axis = "";
}
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
if (!axis.empty()) {
axes->push_back(axis);
}
}
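
For intuition, a small Python sketch (not part of this diff) of the layout-string format decoded by the parse_kernel_layout helper above (and presumably by its ParseKernelLayout replacement called in the new IndexRewriter constructor): digits accumulate a per-dimension factor, letters accumulate an axis name. The example layout string below is illustrative only.

def parse_kernel_layout(layout):
    """Python sketch of the C++ helper: split a layout string into
    per-dimension extents and axis names."""
    shape, axes = [], []
    factor, axis = 0, ""
    for c in layout:
        if c.isalpha():
            axis += c
            if factor != 0:  # a letter ends a pending numeric factor
                shape.append(factor)
                factor = 0
        elif c.isdigit():
            factor = factor * 10 + int(c)
            if axis:  # a digit ends a pending axis name
                axes.append(axis)
                axis = ""
        else:
            raise ValueError("Invalid layout " + layout)
    if axis:  # flush the trailing axis name
        axes.append(axis)
    return shape, axes

# Illustrative example: parse_kernel_layout("32c16w") == ([32, 16], ["c", "w"])
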

std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }

class IndexRewriter : public StmtExprMutator {
public:
IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
: placeholder_op_(placeholder_op), new_layout_(new_layout) {}
: placeholder_op_(placeholder_op) {
ParseKernelLayout(new_layout, &new_shape_, &new_names_);
}

PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }

PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
te::Tensor t = Downcast<te::Tensor>(op->producer);
if (t->op == placeholder_op_) {
Array<PrimExpr> new_shape;
std::vector<std::string> new_names;
parse_kernel_layout(new_layout_, &new_shape, &new_names);
std::unordered_map<std::string, PrimExpr> name_to_arg;
for (const auto& arg : op->indices) {
std::string axis_name;
if (const auto* pimm = arg.as<IntImmNode>()) {
CHECK_EQ(pimm->value, 0);
axis_name = "IntImm";
} else {
axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
CHECK_EQ(name_to_arg.count(axis_name), 0);
name_to_arg[axis_name] = arg;
}
}

std::unordered_map<std::string, PrimExpr> div_factors;
std::vector<PrimExpr> r_new_args;
for (int i = new_names.size() - 1; i >= 0; --i) {
auto ori_iter_name = new_names[i];
for (int i = new_names_.size() - 1; i >= 0; --i) {
auto ori_iter_name = new_names_[i];
auto name_it = name_to_arg.find(ori_iter_name);
CHECK(name_it != name_to_arg.end());
PrimExpr ori_arg = name_it->second;

PrimExpr mod_factor = new_shape[i];
PrimExpr mod_factor = new_shape_[i];

PrimExpr div_factor = 1;
if (div_factors.count(ori_iter_name)) {
div_factor = div_factors[ori_iter_name];
}
div_factors[ori_iter_name] = div_factor * new_shape[i];
div_factors[ori_iter_name] = div_factor * new_shape_[i];

PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);

@@ -753,7 +721,8 @@ class IndexRewriter : public StmtExprMutator {

private:
const te::Operation& placeholder_op_;
const std::string& new_layout_;
Array<PrimExpr> new_shape_;
std::vector<std::string> new_names_;
};

std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
@@ -774,7 +743,7 @@ std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const
CHECK_EQ(pimm->value, 0);
axis_name = "IntImm";
} else {
axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
axis_name = AxisBaseName(CleanName(Downcast<Var>(e)->name_hint));
}

placeholder_axis_names->insert(axis_name);
@@ -833,7 +802,7 @@ std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const
// fused iters have been replaced with iter->ori_iters.
// So there should be only one ori iter name extracted from iter->name.
CHECK_EQ(ori_iter_names.size(), 1);
auto ori_iter_name = BaseName(*ori_iter_names.begin());
auto ori_iter_name = AxisBaseName(*ori_iter_names.begin());
new_axis_names.push_back(ori_iter_name);
}
for (size_t i = 0; i < new_axis_names.size(); ++i) {
@@ -868,135 +837,136 @@ void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
for (const auto& stage : state->stages) {
stage_id += 1;
const te::Operation& op = stage->op;
if (op->IsInstance<te::ComputeOpNode>()) {
const Map<String, ObjectRef>& attrs = op->attrs;
if (attrs.count(layout_free_placeholders_key)) {
const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
for (const auto& placeholder : placeholders) {
const auto& placeholder_op = placeholder->op;

// Check whether this placeholder has already been handled
if (handled_ops.count(placeholder_op)) {
continue;
}
if (!op->IsInstance<te::ComputeOpNode>()) {
continue;
}
const Map<String, ObjectRef>& attrs = op->attrs;
if (attrs.count(layout_free_placeholders_key) == 0) {
continue;
}
const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
for (const auto& placeholder : placeholders) {
const auto& placeholder_op = placeholder->op;

// Check whether this placeholder has already been handled
if (handled_ops.count(placeholder_op)) {
continue;
}

// Skip the op that is not a direct consumer of this placeholder.
// This is usually caused by cache read/write.
bool direct_consumer = false;
for (auto& t : op->InputTensors()) {
if (t->op == placeholder_op) {
direct_consumer = true;
break;
}
}
if (!direct_consumer) {
continue;
}

std::set<std::string> placeholder_axis_names;
get_ori_layout(&placeholder_axis_names, op, placeholder);

Array<PrimExpr> new_shape;
std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op, placeholder,
placeholder_axis_names);

handled_ops.insert(placeholder_op);

Array<te::Operation> old_ops = pdag->ops;
ArrayNode* pops = pdag->ops.CopyOnWrite();

// skip the op that is not direct consumer of this placeholder,
// mostly due to cache read/write.
bool direct_consumer = false;
// Create new placeholder
te::Operation new_placeholder_op;
new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);

te::Operation new_compute_op, old_compute_op;
Array<PrimExpr> new_body;
IndexRewriter index_rewriter(placeholder_op, new_layout);
for (auto& op : old_ops) {
if (auto* pop = op.as<te::ComputeOpNode>()) {
bool need_update = false;
for (auto& t : op->InputTensors()) {
if (t->op == placeholder_op) {
direct_consumer = true;
need_update = true;
break;
}
}
if (!direct_consumer) {
continue;
}

std::set<std::string> placeholder_axis_names;
get_ori_layout(&placeholder_axis_names, op, placeholder);

Array<PrimExpr> new_shape;
std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
placeholder, placeholder_axis_names);

handled_ops.insert(placeholder_op);

Array<te::Operation> old_ops = pdag->ops;
ArrayNode* pops = pdag->ops.CopyOnWrite();

// Create new placeholder
te::Operation new_placeholder_op;
new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);

te::Operation new_compute_op, old_compute_op;
Array<PrimExpr> new_body;
IndexRewriter index_rewriter(placeholder_op, new_layout);
for (auto& op : old_ops) {
if (auto* pop = op.as<te::ComputeOpNode>()) {
bool need_update = false;
for (auto& t : op->InputTensors()) {
if (t->op == placeholder_op) {
need_update = true;
break;
}
}
if (need_update) {
for (auto& body : pop->body) {
new_body.push_back(index_rewriter.Rewrite(body));
}
old_compute_op = op;
CHECK(!new_compute_op.defined());
new_compute_op =
te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
}
}
}

// construct the map from old_op to new_op
std::unordered_map<te::Operation, te::Operation> updated_ops;
for (size_t i = 0; i < old_ops.size(); ++i) {
auto old_op = old_ops[i];
if (old_op == placeholder_op) {
pops->SetItem(i, new_placeholder_op);
updated_ops[placeholder_op] = new_placeholder_op;
} else if (old_op == old_compute_op) {
pops->SetItem(i, new_compute_op);
updated_ops[old_compute_op] = new_compute_op;
} else {
pops->SetItem(i, old_op);
}
}

// Because ops is sorted in topo-order, only do one pass linear scan here.
for (size_t i = 0; i < pops->size(); ++i) {
auto old_op = Downcast<te::Operation>(pops->at(i));
if (auto* pop = old_op.as<te::ComputeOpNode>()) {
auto inputs = pop->InputTensors();
std::unordered_map<te::Tensor, te::Tensor> rmap;
for (auto input : inputs) {
auto it = updated_ops.find(input->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
it = updated_ops.find(new_op);
}
if (new_op.defined()) {
int index = input->value_index;
rmap[input] = new_op.output(index);
}
}
if (!rmap.empty()) {
te::Operation new_op = pop->ReplaceInputs(old_op, rmap);
updated_ops[old_op] = new_op;
pops->SetItem(i, new_op);
}
if (need_update) {
for (auto& body : pop->body) {
new_body.push_back(index_rewriter.Rewrite(body));
}
old_compute_op = op;
CHECK(!new_compute_op.defined());
new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
}
}
}

pdag->init_state = State(pdag->ops);

Array<te::Tensor> old_tensors = pdag->tensors;
ArrayNode* ptensors = pdag->tensors.CopyOnWrite();
// construct the map from old_op to new_op
std::unordered_map<te::Operation, te::Operation> updated_ops;
for (size_t i = 0; i < old_ops.size(); ++i) {
auto old_op = old_ops[i];
if (old_op == placeholder_op) {
pops->SetItem(i, new_placeholder_op);
updated_ops[placeholder_op] = new_placeholder_op;
} else if (old_op == old_compute_op) {
pops->SetItem(i, new_compute_op);
updated_ops[old_compute_op] = new_compute_op;
} else {
pops->SetItem(i, old_op);
}
}

for (size_t i = 0; i < old_tensors.size(); ++i) {
const auto& old_tensor = old_tensors[i];
auto it = updated_ops.find(old_tensor->op);
// Because ops is sorted in topo-order, only do one pass linear scan here.
for (size_t i = 0; i < pops->size(); ++i) {
auto old_op = Downcast<te::Operation>(pops->at(i));
if (auto* pop = old_op.as<te::ComputeOpNode>()) {
auto inputs = pop->InputTensors();
std::unordered_map<te::Tensor, te::Tensor> rmap;
for (auto input : inputs) {
auto it = updated_ops.find(input->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
it = updated_ops.find(new_op);
}
if (new_op.defined()) {
auto index = old_tensor->value_index;
ptensors->SetItem(i, new_op.output(index));
int index = input->value_index;
rmap[input] = new_op.output(index);
}
}
} // end for placeholder
if (!rmap.empty()) {
te::Operation new_op = pop->ReplaceInputs(old_op, rmap);
updated_ops[old_op] = new_op;
pops->SetItem(i, new_op);
}
}
}

pdag->init_state = State(pdag->ops);

Array<te::Tensor> old_tensors = pdag->tensors;
ArrayNode* ptensors = pdag->tensors.CopyOnWrite();

for (size_t i = 0; i < old_tensors.size(); ++i) {
const auto& old_tensor = old_tensors[i];
auto it = updated_ops.find(old_tensor->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
it = updated_ops.find(new_op);
}
if (new_op.defined()) {
auto index = old_tensor->value_index;
ptensors->SetItem(i, new_op.output(index));
}
}
} // end for compute op
} // end for placeholder
} // end for stage
}

