diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 768cdf099c7ce..a62aac3090dbf 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -923,12 +923,17 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + std::unordered_map axes_stride; + for (const auto& i : origin_axes) { + axes_stride[i] = Integer(1); + } Array new_stride(new_shape.size(), PrimExpr()); PrimExpr temp = Integer(1); for (int i = new_shape.size() - 1; i >= 0; i--) { - new_stride.Set(i, temp); - temp *= new_shape[i]; + new_stride.Set(i, axes_stride[new_axes[i]]); + axes_stride[new_axes[i]] *= new_shape[i]; } + // Add extra layout transpose stage const auto& layout_transform_tensor = te::compute(new_shape, [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes] @@ -963,6 +968,12 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, transform_steps->Set(i, std::move(step)); } } + Array to_fuse; + for (size_t i = 0; i < new_shape.size() - 1; i++) { + to_fuse.push_back(i); + } + transform_steps->push_back(FuseStep(stage_id, to_fuse)); + transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); } else { LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite."; } @@ -1090,7 +1101,6 @@ std::pair> ComputeDAG::ApplySteps( if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { Array steps = transform_steps; const auto& dag = RewriteLayout(&steps, layout_rewrite); - LOG(INFO) << dag; return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); }