Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Oct 24, 2020
1 parent 4aac47f commit c784fba
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -923,12 +923,17 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);
} else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
std::unordered_map<std::string, PrimExpr> axes_stride;
for (const auto& i : origin_axes) {
axes_stride[i] = Integer(1);
}
Array<PrimExpr> new_stride(new_shape.size(), PrimExpr());
PrimExpr temp = Integer(1);
for (int i = new_shape.size() - 1; i >= 0; i--) {
new_stride.Set(i, temp);
temp *= new_shape[i];
new_stride.Set(i, axes_stride[new_axes[i]]);
axes_stride[new_axes[i]] *= new_shape[i];
}

// Add extra layout transpose stage
const auto& layout_transform_tensor = te::compute(new_shape,
[&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes]
Expand Down Expand Up @@ -963,6 +968,12 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
transform_steps->Set(i, std::move(step));
}
}
Array<Integer> to_fuse;
for (size_t i = 0; i < new_shape.size() - 1; i++) {
to_fuse.push_back(i);
}
transform_steps->push_back(FuseStep(stage_id, to_fuse));
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
} else {
LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite.";
}
Expand Down Expand Up @@ -1090,7 +1101,6 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) {
Array<Step> steps = transform_steps;
const auto& dag = RewriteLayout(&steps, layout_rewrite);
LOG(INFO) << dag;
return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite);
}

Expand Down

0 comments on commit c784fba

Please sign in to comment.