Skip to content

Commit

Permalink
CI Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Oct 26, 2020
1 parent 1d21ba1 commit f819b60
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,8 @@ std::string GetOrigLayout(std::set<std::string>* placeholder_axis_names, const t
return orig_layout;
}

std::string GetNewLayout(const State& state, const int stage_id,
const Stage& stage, const te::Operation& op, const te::Tensor& placeholder,
std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage,
const te::Operation& op, const te::Tensor& placeholder,
const std::set<std::string>& placeholder_axis_names) {
std::ostringstream os;
Array<Iterator> stage_iters;
Expand Down Expand Up @@ -911,8 +911,8 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
std::vector<std::string> origin_axes;
ParseKernelLayout(origin_layout, &origin_shape, &origin_axes);

std::string new_layout = GetNewLayout(state, stage_id, stage, op, placeholder,
placeholder_axis_names);
std::string new_layout =
GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names);
Array<PrimExpr> new_shape;
std::vector<std::string> new_axes;
ParseKernelLayout(new_layout, &new_shape, &new_axes);
Expand All @@ -935,9 +935,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
}

// Add extra layout transpose stage
const auto& layout_transform_tensor = te::compute(new_shape,
[&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes]
(const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
const auto& layout_transform_tensor = te::compute(
new_shape,
[&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes,
&new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
Array<PrimExpr> access_indices;
for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) {
PrimExpr temp = Integer(0);
Expand All @@ -949,7 +950,8 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
access_indices.push_back(temp);
}
return placeholder_op.output(0)(access_indices);
}, "auto_schedule_layout_transpose");
},
"auto_schedule_layout_transpose");
new_op_to_update = layout_transform_tensor->op;

// Update the transform steps
Expand Down

0 comments on commit f819b60

Please sign in to comment.