From eb5a73ba6490452bd41a495f579438fec5dee4ae Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Thu, 15 Oct 2020 16:09:14 +0800
Subject: [PATCH 01/14] Add pre-transpose support for layout rewrite

---
 include/tvm/auto_scheduler/compute_dag.h    | 26 ++++--
 include/tvm/auto_scheduler/transform_step.h |  2 +-
 python/tvm/auto_scheduler/compute_dag.py    |  7 +-
 src/auto_scheduler/compute_dag.cc           | 89 ++++++++++++++-----
 src/auto_scheduler/loop_state.cc            |  6 ++
 5 files changed, 104 insertions(+), 26 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index 6e67fef0f283..8bf167e380dc 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -194,6 +194,22 @@ class ComputeDAGNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
 };

+/*!
+ * \brief Several options for applying layout rewrite.
+ * This is an optimization to rewrite the shape of input tensors according to the schedule we get.
+ */
+enum class LayoutRewriteOption : int {
+  /*! \brief Do not process layout rewrite. */
+  NoRewrite = 0,
+  /*!
+   * \brief Modify the placeholder to suit the schedule.
+   * \note This should be used along with the graph optimization in Relay.
+   */
+  RewriteWithPlaceholder = 1,
+  /*! \brief Insert a pre-transpose stage between the placeholder and the compute op to suit the schedule. */
+  RewriteWithPreTranspose = 2
+};
+
 /*!
  * \brief Managed reference to ComputeDAGNode.
  * \sa ComputeDAGNode
@@ -215,7 +231,7 @@ class ComputeDAG : public ObjectRef {
    * according to the loop nest derived with `transform_steps`.
    * \param transform_steps Transform steps of a state.
    */
-  void RewriteLayout(const Array<Step>& transform_steps);
+  void RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite);

   /*!
    * \brief Apply the history transform steps to get a TVM schedule.
@@ -229,10 +245,10 @@ class ComputeDAG : public ObjectRef {
    * \return A `te.schedule` and an Array of `te.Tensor` to be used in `tvm.lower`
    * or `tvm.build`.
    */
-  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
-                                                        Array<te::Stage>* stages = nullptr,
-                                                        StageToAxesMap* stage_to_axes = nullptr,
-                                                        bool layout_rewrite = false) const;
+  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
+      const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
+      StageToAxesMap* stage_to_axes = nullptr,
+      LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;

   /*!
    * \brief Print transform steps as equivalent python schedule API.

diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 7be3554c7c5d..94439d72db9c 100755
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -846,7 +846,7 @@ class ComputeAtStep : public Step {
    */
   explicit ComputeAtStep(dmlc::JSONReader* reader);

-  TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
 };

 /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */

diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 4b1b264c30d8..02ac05191585 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -50,6 +50,11 @@ class ComputeDAG(Object):
     compute : Union[List[Tensor], str, Schedule]
         Input/output tensors or workload key for a compute declaration.
""" + LAYOUT_REWRITE_TABLE = { + "NoRewrite": 0, + "RewriteWithPlaceholder": 1, + "RewriteWithPreTranspose": 2, + } def __init__(self, compute_or_sche): if isinstance(compute_or_sche, str): @@ -81,7 +86,7 @@ def get_init_state(self): """ return State(self.init_state, self) - def apply_steps_from_state(self, state, layout_rewrite=False): + def apply_steps_from_state(self, state, layout_rewrite=LAYOUT_REWRITE_TABLE["NoRewrite"]): """ Apply the history transform steps from a State to get a TVM schedule. diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index c6cf094ee202..63b4160b2255 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -862,10 +862,11 @@ std::string GetNewLayout(Array* new_shape, const State& state, const i return new_layout; } -void ComputeDAG::RewriteLayout(const Array& transform_steps) { +void ComputeDAG::RewriteLayout(Array* transform_steps, + LayoutRewriteOption layout_rewrite) { ComputeDAGNode* p_dag = this->CopyOnWrite(); auto node = make_object(); - node->transform_steps = transform_steps; + node->transform_steps = *transform_steps; node->concrete = true; const State& state = InferBound(State(node)); OperationSet handled_ops; @@ -912,13 +913,32 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { handled_ops.insert(placeholder_op); - Array old_ops = p_dag->ops; - ArrayNode* pops = p_dag->ops.CopyOnWrite(); - - // Create new placeholder - te::Operation new_placeholder_op; - new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape, + te::Operation new_op_to_update; + if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) { + // Create new placeholder + new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); + } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + Array new_stride(new_shape.size(), PrimExpr()); + PrimExpr temp = Integer(1); + for (int i = new_shape.size() - 1; i >= 0; i--) { + new_stride.Set(i, temp); + temp *= new_shape[i]; + } + const auto& layout_transform_tensor = te::compute(new_shape, + [&new_stride, &placeholder_op] + (const tvm::runtime::Array &i) -> tvm::PrimExpr { + return placeholder_op.output(0)( + new_stride[0] * i[0] + new_stride[1] * i[1] + new_stride[2] * i[2] + + new_stride[4] * i[4] + new_stride[6] * i[6], + new_stride[3] * i[3] + new_stride[5] * i[5]); + }, "auto_schedule_layout_transpose"); + new_op_to_update = layout_transform_tensor->op; + } else { + LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite."; + } + + Array old_ops = p_dag->ops; te::Operation new_compute_op, old_compute_op; Array new_body; @@ -945,23 +965,32 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { // construct the map from old_op to new_op std::unordered_map updated_ops; + + p_dag->ops.clear(); for (size_t i = 0; i < old_ops.size(); ++i) { - auto old_op = old_ops[i]; + const auto old_op = old_ops[i]; if (old_op == placeholder_op) { - pops->SetItem(i, new_placeholder_op); - updated_ops[placeholder_op] = new_placeholder_op; + if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + p_dag->ops.push_back(placeholder_op); + } + p_dag->ops.push_back(new_op_to_update); + updated_ops[placeholder_op] = new_op_to_update; } else if (old_op == old_compute_op) { - pops->SetItem(i, new_compute_op); + p_dag->ops.push_back(new_compute_op); updated_ops[old_compute_op] = new_compute_op; } else { - pops->SetItem(i, old_op); + p_dag->ops.push_back(old_op); } } + 
ArrayNode* pops = p_dag->ops.CopyOnWrite(); // Because ops is sorted in topo-order, only do one pass linear scan here. for (size_t i = 0; i < pops->size(); ++i) { auto old_op = Downcast(pops->at(i)); if (auto* pop = old_op.as()) { + if (old_op == new_op_to_update) { + continue; + } auto inputs = pop->InputTensors(); std::unordered_map rmap; for (auto input : inputs) { @@ -989,6 +1018,10 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { for (size_t i = 0; i < old_tensors.size(); ++i) { const auto& old_tensor = old_tensors[i]; + if (layout_rewrite != LayoutRewriteOption::RewriteWithPlaceholder && + old_tensor->op->IsInstance()) { + continue; + } auto it = updated_ops.find(old_tensor->op); te::Operation new_op; while (it != updated_ops.end()) { @@ -1018,15 +1051,32 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { } p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); p_dag->init_state = State(p_dag->ops); + + if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + for (size_t i = 0; i < transform_steps->size(); i++) { + Step ss = (*transform_steps)[i]; + if (ss->stage_id >= 2) { + ss->stage_id++; + } + if (ss->IsInstance()) { + auto ps = tvm::Downcast(ss); + if (ps->target_stage_id >= 2) { + ps->target_stage_id++; + } + } + transform_steps->Set(i, std::move(ss)); + } + } } std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes, - bool layout_rewrite) const { - if (layout_rewrite && !transform_steps.empty()) { + LayoutRewriteOption layout_rewrite) const { + if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { ComputeDAG new_dag = *this; - new_dag.RewriteLayout(transform_steps); - return new_dag.ApplySteps(transform_steps, stages, stage_to_axes, false); + Array steps = transform_steps; + new_dag.RewriteLayout(&steps, layout_rewrite); + return new_dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); } // Temporal object to be used if the input pointer is nullptr @@ -1305,11 +1355,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG") }); TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState") - .set_body_typed([](const ComputeDAG& dag, const State& state, const bool layout_rewrite) { + .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { te::Schedule sch; Array return_tensors; std::tie(sch, return_tensors) = - dag.ApplySteps(state->transform_steps, nullptr, nullptr, layout_rewrite); + dag.ApplySteps(state->transform_steps, nullptr, nullptr, + static_cast(layout_rewrite)); return Array{sch, return_tensors}; }); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 23d6eb64da6c..517f7ff91f55 100755 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -445,6 +445,12 @@ String State::ToStr(bool delete_trivial_loop) const { return os.str(); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto& stage = tvm::Downcast(ref); + p->stream << stage->GetTypeKey() << "(" << stage.get() << ": " << stage->op->name << ")"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { PrintState(&p->stream, tvm::Downcast(ref), true); From fc91f39e868998a9047d084850c48322e5b60286 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 15 Oct 2020 20:47:05 +0800 Subject: [PATCH 02/14] Update --- include/tvm/auto_scheduler/compute_dag.h | 2 
+- src/auto_scheduler/compute_dag.cc | 140 +++++++++++------- .../test_auto_scheduler_layout_rewrite.py | 69 ++++++++- 3 files changed, 146 insertions(+), 65 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 8bf167e380dc..ede60846d4d1 100755 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -231,7 +231,7 @@ class ComputeDAG : public ObjectRef { * according to the loop nest derived with `transform_steps`. * \param transform_steps Transform steps of a state. */ - void RewriteLayout(Array* transform_steps, LayoutRewriteOption layout_rewrite); + ComputeDAG RewriteLayout(Array* transform_steps, LayoutRewriteOption layout_rewrite) const; /*! * \brief Apply the history transform steps to get a TVM schedule. diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 63b4160b2255..8d4a5d9215ac 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -796,7 +796,7 @@ std::string GetOrigLayout(std::set* placeholder_axis_names, const t return orig_layout; } -std::string GetNewLayout(Array* new_shape, const State& state, const int stage_id, +std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage, const te::Operation& op, const te::Tensor& placeholder, const std::set& placeholder_axis_names) { std::ostringstream os; @@ -852,7 +852,6 @@ std::string GetNewLayout(Array* new_shape, const State& state, const i if (placeholder_axis_names.count(ori_iter_name)) { os << iter->range->extent << ori_iter_name; new_names.push_back(ori_iter_name); - new_shape->push_back(iter->range->extent); } } std::string new_layout = os.str(); @@ -862,17 +861,21 @@ std::string GetNewLayout(Array* new_shape, const State& state, const i return new_layout; } -void ComputeDAG::RewriteLayout(Array* transform_steps, - LayoutRewriteOption layout_rewrite) { - ComputeDAGNode* p_dag = this->CopyOnWrite(); +ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, + LayoutRewriteOption layout_rewrite) const { + LOG(INFO) << "rewrite layout in"; + ComputeDAG new_dag = *this; + ComputeDAGNode* p_dag = new_dag.CopyOnWrite(); + auto node = make_object(); node->transform_steps = *transform_steps; node->concrete = true; const State& state = InferBound(State(node)); + OperationSet handled_ops; - int stage_id = -1; - for (const auto& stage : state->stages) { - stage_id += 1; + for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) { + const auto& stage = state->stages[stage_id]; + const te::Operation& op = stage->op; if (!op->IsInstance()) { continue; @@ -882,15 +885,13 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, continue; } const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; - Array placeholders = Downcast>(attr_value); - for (const auto& placeholder : placeholders) { + for (const auto& placeholder : Downcast>(attr_value)) { const auto& placeholder_op = placeholder->op; // Check whether this placeholder has already been handled if (handled_ops.count(placeholder_op)) { continue; } - // Skip the op that is not direct consumer of this placeholder. // This is usually caused by cache read/write. 
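 // (For example, after a cache_read step the compute op reads the cache
 // stage's output instead of the placeholder, so the placeholder no longer
 // appears among that op's direct input tensors.)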
bool direct_consumer = false; @@ -903,15 +904,19 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, if (!direct_consumer) { continue; } + handled_ops.insert(placeholder_op); std::set placeholder_axis_names; - GetOrigLayout(&placeholder_axis_names, op, placeholder); + std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder); + Array origin_shape; + std::vector origin_axes; + ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); + std::string new_layout = GetNewLayout(state, stage_id, stage, op, placeholder, + placeholder_axis_names); Array new_shape; - std::string new_layout = - GetNewLayout(&new_shape, state, stage_id, stage, op, placeholder, placeholder_axis_names); - - handled_ops.insert(placeholder_op); + std::vector new_axes; + ParseKernelLayout(new_layout, &new_shape, &new_axes); te::Operation new_op_to_update; if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) { @@ -925,25 +930,53 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, new_stride.Set(i, temp); temp *= new_shape[i]; } + Array access_indices; + for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { + PrimExpr temp = Integer(0); + for (size_t i = 0; i < new_shape.size(); i++) { + if (origin_axes[indice_index].compare(new_axes[i]) == 0) { + temp += new_shape[i] * new_stride[i]; + } + } + access_indices.push_back(temp); + } + + // Add extra layout transpose stage const auto& layout_transform_tensor = te::compute(new_shape, - [&new_stride, &placeholder_op] - (const tvm::runtime::Array &i) -> tvm::PrimExpr { - return placeholder_op.output(0)( - new_stride[0] * i[0] + new_stride[1] * i[1] + new_stride[2] * i[2] + - new_stride[4] * i[4] + new_stride[6] * i[6], - new_stride[3] * i[3] + new_stride[5] * i[5]); + [&new_stride, &placeholder_op, &access_indices] + (const tvm::runtime::Array& i) -> tvm::PrimExpr { + return placeholder_op.output(0)(access_indices); }, "auto_schedule_layout_transpose"); new_op_to_update = layout_transform_tensor->op; + + // Update the transform steps + LOG(INFO) << stage_id; + for (size_t i = 0; i < transform_steps->size(); i++) { + Step step = (*transform_steps)[i]; + if (step->stage_id >= static_cast(stage_id)) { + step->stage_id++; + // step.CopyOnWrite()->stage_id++; + } + if (step->IsInstance()) { + auto compute_at_step = tvm::Downcast(step); + if (compute_at_step->target_stage_id >= static_cast(stage_id)) { + compute_at_step->target_stage_id++; + } + transform_steps->Set(i, std::move(compute_at_step)); + } else { + transform_steps->Set(i, std::move(step)); + } + } } else { LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite."; } - Array old_ops = p_dag->ops; + Array original_ops = p_dag->ops; - te::Operation new_compute_op, old_compute_op; + te::Operation new_compute_op, original_compute_op; Array new_body; IndexRewriter index_rewriter(placeholder_op, new_layout); - for (auto& op : old_ops) { + for (auto& op : original_ops) { if (auto* pop = op.as()) { bool need_update = false; for (auto& t : op->InputTensors()) { @@ -956,39 +989,39 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, for (auto& body : pop->body) { new_body.push_back(index_rewriter.Rewrite(body)); } - old_compute_op = op; - ICHECK(!new_compute_op.defined()); + original_compute_op = op; + CHECK(!new_compute_op.defined()); new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body); } } } - // construct the map from old_op to new_op + // construct the map from original_op to new_op std::unordered_map 
updated_ops; p_dag->ops.clear(); - for (size_t i = 0; i < old_ops.size(); ++i) { - const auto old_op = old_ops[i]; - if (old_op == placeholder_op) { + for (size_t i = 0; i < original_ops.size(); ++i) { + const auto& original_op = original_ops[i]; + if (original_op == placeholder_op) { if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { p_dag->ops.push_back(placeholder_op); } p_dag->ops.push_back(new_op_to_update); updated_ops[placeholder_op] = new_op_to_update; - } else if (old_op == old_compute_op) { + } else if (original_op == original_compute_op) { p_dag->ops.push_back(new_compute_op); - updated_ops[old_compute_op] = new_compute_op; + updated_ops[original_compute_op] = new_compute_op; } else { - p_dag->ops.push_back(old_op); + p_dag->ops.push_back(original_op); } } ArrayNode* pops = p_dag->ops.CopyOnWrite(); // Because ops is sorted in topo-order, only do one pass linear scan here. for (size_t i = 0; i < pops->size(); ++i) { - auto old_op = Downcast(pops->at(i)); - if (auto* pop = old_op.as()) { - if (old_op == new_op_to_update) { + const auto& original_op = Downcast(pops->at(i)); + if (auto* pop = original_op.as()) { + if (original_op == new_op_to_update) { continue; } auto inputs = pop->InputTensors(); @@ -1006,8 +1039,8 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, } } if (!rmap.empty()) { - te::Operation new_op = pop->ReplaceInputs(old_op, rmap); - updated_ops[old_op] = new_op; + te::Operation new_op = pop->ReplaceInputs(original_op, rmap); + updated_ops[original_op] = new_op; pops->SetItem(i, new_op); } } @@ -1052,31 +1085,19 @@ void ComputeDAG::RewriteLayout(Array* transform_steps, p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); p_dag->init_state = State(p_dag->ops); - if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { - for (size_t i = 0; i < transform_steps->size(); i++) { - Step ss = (*transform_steps)[i]; - if (ss->stage_id >= 2) { - ss->stage_id++; - } - if (ss->IsInstance()) { - auto ps = tvm::Downcast(ss); - if (ps->target_stage_id >= 2) { - ps->target_stage_id++; - } - } - transform_steps->Set(i, std::move(ss)); - } - } + LOG(INFO) << "rewrite layout out"; + + return new_dag; } std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes, LayoutRewriteOption layout_rewrite) const { + LOG(INFO) << this << " " << int(layout_rewrite); if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { - ComputeDAG new_dag = *this; Array steps = transform_steps; - new_dag.RewriteLayout(&steps, layout_rewrite); - return new_dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); + const auto& dag = RewriteLayout(&steps, layout_rewrite); + return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); } // Temporal object to be used if the input pointer is nullptr @@ -1183,6 +1204,7 @@ State ComputeDAG::InferBound(const State& state) const { te::Schedule sch; Array tensors; // Replay steps to tvm::Schedule + LOG(INFO) << "infer bound"; std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); // Get bound information from TVM schedule @@ -1358,6 +1380,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { te::Schedule sch; Array return_tensors; + LOG(INFO) << state->transform_steps; + for (auto i : state->transform_steps) { + LOG(INFO) << i->stage_id; + } 
std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, nullptr, nullptr, static_cast(layout_rewrite)); diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 3ce7a438eef4..026bf7593e1b 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -28,18 +28,27 @@ def test_apply_steps_with_layout_rewrite(): dag, s = get_tiled_matmul() - _, bufs = dag.apply_steps_from_state(s, layout_rewrite=False) + _, bufs = dag.apply_steps_from_state(s) + print("=======") assert bufs[1].shape[0] == 512 assert bufs[1].shape[1] == 512 - _, bufs = dag.apply_steps_from_state(s, layout_rewrite=True) + _, bufs = dag.apply_steps_from_state(s, + layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"]) + print("=======") assert bufs[1].shape[0] == 4 assert bufs[1].shape[1] == 8 assert bufs[1].shape[2] == 4 assert bufs[1].shape[3] == 4 assert bufs[1].shape[4] == 512 + _, bufs = dag.apply_steps_from_state(s, + layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) + print("=======") + assert bufs[1].shape[0] == 512 + assert bufs[1].shape[1] == 512 + _, bufs = dag.apply_steps_from_state(s) -def test_layout_rewrite_correctness(): +def test_correctness_layout_rewrite_with_placeholder(): N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -50,16 +59,18 @@ def test_layout_rewrite_correctness(): search_policy = auto_scheduler.SketchPolicy(task) + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tuning_options = auto_scheduler.TuningOptions( num_measure_trials=2, - runner="local", + runner=measure_ctx.runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) - s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True) - s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False) + s, bufs = dag.apply_steps_from_state(inp.state, + layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"]) + s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] np_args_ref = [np.array(x) for x in np_args] @@ -104,6 +115,50 @@ def test_layout_rewrite_correctness(): np.testing.assert_allclose(np_args[2], np_args_ref[2]) +def test_correctness_layout_rewrite_with_pre_transpose(): + N = 128 + target = tvm.target.Target("llvm") + task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) + dag = task.compute_dag + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + + search_policy = auto_scheduler.SketchPolicy(task) + + measure_ctx = auto_scheduler.LocalRPCMeasureContext() + tuning_options = auto_scheduler.TuningOptions( + num_measure_trials=2, + runner=measure_ctx.runner, + verbose=1, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + ) + auto_scheduler.auto_schedule(task, search_policy, tuning_options) + inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) + s, bufs = dag.apply_steps_from_state(inp.state, + layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) + s_ref, 
bufs_ref = dag.apply_steps_from_state(inp.state) + np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] + np_args_ref = [np.array(x) for x in np_args] + + func = tvm.build(s, bufs, target=target) + func_ref = tvm.build(s_ref, bufs_ref, target=target) + + ctx = tvm.context(str(target)) + ctx_ref = tvm.cpu() + + args = [tvm.nd.array(x, ctx=ctx) for x in np_args] + args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref] + ctx.sync() + + func(*args) + func_ref(*args_ref) + ctx.sync() + + np.testing.assert_allclose(np_args, np_args_ref) + + if __name__ == "__main__": test_apply_steps_with_layout_rewrite() - test_layout_rewrite_correctness() + # test_correctness_layout_rewrite_with_placeholder() + # test_correctness_layout_rewrite_with_pre_transpose() From e7d83ef9730efee44fdb008f2256396c8eae58f2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 16 Oct 2020 17:47:10 +0800 Subject: [PATCH 03/14] Bug fix --- include/tvm/auto_scheduler/transform_step.h | 36 +++++++------ src/auto_scheduler/compute_dag.cc | 16 ++---- src/auto_scheduler/transform_step.cc | 52 +++++++++++++++++++ .../test_auto_scheduler_layout_rewrite.py | 10 ++-- 4 files changed, 80 insertions(+), 34 deletions(-) diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 94439d72db9c..3d4494ee1a6a 100755 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -134,7 +134,7 @@ class IteratorNode : public Object { } static constexpr const char* _type_key = "auto_scheduler.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(IteratorNode, Object); }; /*! @@ -182,7 +182,9 @@ class StepNode : public Object { */ class Step : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); + StepNode* CopyOnWrite(); + + TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); }; // Forward declaration @@ -267,7 +269,7 @@ class AnnotationStepNode : public StepNode { static constexpr const char* record_prefix_str = "AN"; static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode); }; /*! @@ -330,7 +332,7 @@ class FuseStepNode : public StepNode { static constexpr const char* record_prefix_str = "FU"; static constexpr const char* _type_key = "auto_scheduler.FuseStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode); }; /*! @@ -390,7 +392,7 @@ class PragmaStepNode : public StepNode { static constexpr const char* record_prefix_str = "PR"; static constexpr const char* _type_key = "auto_scheduler.PragmaStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode); }; /*! @@ -452,7 +454,7 @@ class ReorderStepNode : public StepNode { static constexpr const char* record_prefix_str = "RE"; static constexpr const char* _type_key = "auto_scheduler.ReorderStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode); }; /*! 
@@ -527,7 +529,7 @@ class SplitStepNode : public StepNode { static constexpr const char* record_prefix_str = "SP"; static constexpr const char* _type_key = "auto_scheduler.SplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode); }; /*! @@ -607,7 +609,7 @@ class FollowSplitStepNode : public StepNode { static constexpr const char* record_prefix_str = "FSP"; static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode); }; /*! @@ -688,7 +690,7 @@ class FollowFusedSplitStepNode : public StepNode { static constexpr const char* record_prefix_str = "FFSP"; static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode); }; /*! @@ -754,7 +756,7 @@ class StorageAlignStepNode : public StepNode { static constexpr const char* record_prefix_str = "SA"; static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode); }; /*! @@ -822,7 +824,7 @@ class ComputeAtStepNode : public StepNode { static constexpr const char* record_prefix_str = "CA"; static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode); }; /*! @@ -846,7 +848,7 @@ class ComputeAtStep : public Step { */ explicit ComputeAtStep(dmlc::JSONReader* reader); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); }; /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ @@ -879,7 +881,7 @@ class ComputeInlineStepNode : public StepNode { static constexpr const char* record_prefix_str = "CI"; static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode); }; /*! @@ -938,7 +940,7 @@ class ComputeRootStepNode : public StepNode { static constexpr const char* record_prefix_str = "CR"; static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode); }; /*! @@ -1010,7 +1012,7 @@ class CacheReadStepNode : public StepNode { static constexpr const char* record_prefix_str = "CHR"; static constexpr const char* _type_key = "auto_scheduler.CacheReadStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode); }; /*! @@ -1081,7 +1083,7 @@ class CacheWriteStepNode : public StepNode { static constexpr const char* record_prefix_str = "CHW"; static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode); }; /*! 
@@ -1148,7 +1150,7 @@ class RfactorStepNode : public StepNode { static constexpr const char* record_prefix_str = "RF"; static constexpr const char* _type_key = "auto_scheduler.RfactorStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode); }; /*! diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 8d4a5d9215ac..96d7e125adec 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -863,7 +863,6 @@ std::string GetNewLayout(const State& state, const int stage_id, ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, LayoutRewriteOption layout_rewrite) const { - LOG(INFO) << "rewrite layout in"; ComputeDAG new_dag = *this; ComputeDAGNode* p_dag = new_dag.CopyOnWrite(); @@ -950,17 +949,15 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, new_op_to_update = layout_transform_tensor->op; // Update the transform steps - LOG(INFO) << stage_id; for (size_t i = 0; i < transform_steps->size(); i++) { Step step = (*transform_steps)[i]; if (step->stage_id >= static_cast(stage_id)) { - step->stage_id++; - // step.CopyOnWrite()->stage_id++; + step.CopyOnWrite()->stage_id++; } if (step->IsInstance()) { auto compute_at_step = tvm::Downcast(step); if (compute_at_step->target_stage_id >= static_cast(stage_id)) { - compute_at_step->target_stage_id++; + dynamic_cast(step.CopyOnWrite())->target_stage_id++; } transform_steps->Set(i, std::move(compute_at_step)); } else { @@ -1085,18 +1082,16 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); p_dag->init_state = State(p_dag->ops); - LOG(INFO) << "rewrite layout out"; - return new_dag; } std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes, LayoutRewriteOption layout_rewrite) const { - LOG(INFO) << this << " " << int(layout_rewrite); if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { Array steps = transform_steps; const auto& dag = RewriteLayout(&steps, layout_rewrite); + LOG(INFO) << dag; return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); } @@ -1204,7 +1199,6 @@ State ComputeDAG::InferBound(const State& state) const { te::Schedule sch; Array tensors; // Replay steps to tvm::Schedule - LOG(INFO) << "infer bound"; std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); // Get bound information from TVM schedule @@ -1380,10 +1374,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { te::Schedule sch; Array return_tensors; - LOG(INFO) << state->transform_steps; - for (auto i : state->transform_steps) { - LOG(INFO) << i->stage_id; - } std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, nullptr, nullptr, static_cast(layout_rewrite)); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 852f1e1f17d8..5560907dcffa 100755 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -122,6 +122,58 @@ const char* IteratorAnnotationString[] = { "tensorize" // kTensorized = 11 }; +StepNode* Step::CopyOnWrite() { + CHECK(data_ != nullptr); + if (!data_.unique()) { + if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = 
as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else { + LOG(FATAL) << "Invalid step: " << (*this); + } + } + return static_cast(data_.get()); +} + Step StepReadFromRecord(dmlc::JSONReader* reader) { std::string name; bool s; diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 026bf7593e1b..60207a036f90 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -29,12 +29,10 @@ def test_apply_steps_with_layout_rewrite(): dag, s = get_tiled_matmul() _, bufs = dag.apply_steps_from_state(s) - print("=======") assert bufs[1].shape[0] == 512 assert bufs[1].shape[1] == 512 _, bufs = dag.apply_steps_from_state(s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"]) - print("=======") assert bufs[1].shape[0] == 4 assert bufs[1].shape[1] == 8 assert bufs[1].shape[2] == 4 @@ -42,10 +40,8 @@ def test_apply_steps_with_layout_rewrite(): assert bufs[1].shape[4] == 512 _, bufs = dag.apply_steps_from_state(s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) - print("=======") assert bufs[1].shape[0] == 512 assert bufs[1].shape[1] == 512 - _, bufs = dag.apply_steps_from_state(s) def test_correctness_layout_rewrite_with_placeholder(): @@ -135,8 +131,14 @@ def test_correctness_layout_rewrite_with_pre_transpose(): ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) + print(">>>") s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) + print(bufs) + print("<<<") + print(tvm.lower(s, bufs, simple_mode=True)) + exit(0) + s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] np_args_ref = [np.array(x) for x in np_args] From 6eb5fe75d6a519c412d7dbcccc2fb322119f2af3 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 16 Oct 2020 17:56:21 +0800 Subject: [PATCH 04/14] Bug fix --- 
src/auto_scheduler/compute_dag.cc | 25 +++++++++---------- .../test_auto_scheduler_layout_rewrite.py | 3 +-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 96d7e125adec..38496d682e47 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -929,21 +929,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, new_stride.Set(i, temp); temp *= new_shape[i]; } - Array access_indices; - for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { - PrimExpr temp = Integer(0); - for (size_t i = 0; i < new_shape.size(); i++) { - if (origin_axes[indice_index].compare(new_axes[i]) == 0) { - temp += new_shape[i] * new_stride[i]; - } - } - access_indices.push_back(temp); - } - // Add extra layout transpose stage const auto& layout_transform_tensor = te::compute(new_shape, - [&new_stride, &placeholder_op, &access_indices] - (const tvm::runtime::Array& i) -> tvm::PrimExpr { + [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes] + (const tvm::runtime::Array& indices) -> tvm::PrimExpr { + Array access_indices; + for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { + PrimExpr temp = Integer(0); + for (size_t i = 0; i < new_shape.size(); i++) { + if (origin_axes[indice_index].compare(new_axes[i]) == 0) { + temp += indices[i] * new_stride[i]; + } + } + access_indices.push_back(temp); + } return placeholder_op.output(0)(access_indices); }, "auto_schedule_layout_transpose"); new_op_to_update = layout_transform_tensor->op; diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 60207a036f90..d51a0f55e5e6 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -131,7 +131,6 @@ def test_correctness_layout_rewrite_with_pre_transpose(): ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) - print(">>>") s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) print(bufs) @@ -163,4 +162,4 @@ def test_correctness_layout_rewrite_with_pre_transpose(): if __name__ == "__main__": test_apply_steps_with_layout_rewrite() # test_correctness_layout_rewrite_with_placeholder() - # test_correctness_layout_rewrite_with_pre_transpose() + test_correctness_layout_rewrite_with_pre_transpose() From 4aca3865bfc8d58b17df2657d673e3faf9b17a9e Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 24 Oct 2020 17:59:32 +0800 Subject: [PATCH 05/14] Update --- src/auto_scheduler/compute_dag.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 38496d682e47..9a7d3d815e2c 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -923,12 +923,17 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + std::unordered_map axes_stride; + for (const auto& i : origin_axes) { + axes_stride[i] = Integer(1); + } Array new_stride(new_shape.size(), PrimExpr()); PrimExpr 
temp = Integer(1); for (int i = new_shape.size() - 1; i >= 0; i--) { - new_stride.Set(i, temp); - temp *= new_shape[i]; + new_stride.Set(i, axes_stride[new_axes[i]]); + axes_stride[new_axes[i]] *= new_shape[i]; } + // Add extra layout transpose stage const auto& layout_transform_tensor = te::compute(new_shape, [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes] @@ -963,6 +968,12 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, transform_steps->Set(i, std::move(step)); } } + Array to_fuse; + for (size_t i = 0; i < new_shape.size() - 1; i++) { + to_fuse.push_back(i); + } + transform_steps->push_back(FuseStep(stage_id, to_fuse)); + transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); } else { LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite."; } @@ -1090,7 +1101,6 @@ std::pair> ComputeDAG::ApplySteps( if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { Array steps = transform_steps; const auto& dag = RewriteLayout(&steps, layout_rewrite); - LOG(INFO) << dag; return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); } From 1d21ba13c9b28ad57a3327620426f0d63851c7b1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 24 Oct 2020 20:53:58 +0800 Subject: [PATCH 06/14] Bug fix --- src/auto_scheduler/compute_dag.cc | 4 ++-- .../test_auto_scheduler_layout_rewrite.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 9a7d3d815e2c..4e806fbc0c06 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -961,7 +961,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, if (step->IsInstance()) { auto compute_at_step = tvm::Downcast(step); if (compute_at_step->target_stage_id >= static_cast(stage_id)) { - dynamic_cast(step.CopyOnWrite())->target_stage_id++; + dynamic_cast(compute_at_step.CopyOnWrite())->target_stage_id++; } transform_steps->Set(i, std::move(compute_at_step)); } else { @@ -1101,7 +1101,7 @@ std::pair> ComputeDAG::ApplySteps( if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { Array steps = transform_steps; const auto& dag = RewriteLayout(&steps, layout_rewrite); - return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); + return dag.ApplySteps(steps); } // Temporal object to be used if the input pointer is nullptr diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index d51a0f55e5e6..356743861bef 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -107,8 +107,8 @@ def test_correctness_layout_rewrite_with_placeholder(): func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(np_args[0], np_args_ref[0]) - np.testing.assert_allclose(np_args[2], np_args_ref[2]) + np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) + np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) def test_correctness_layout_rewrite_with_pre_transpose(): @@ -133,14 +133,9 @@ def test_correctness_layout_rewrite_with_pre_transpose(): inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) - 
print(bufs) - print("<<<") - print(tvm.lower(s, bufs, simple_mode=True)) - exit(0) s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] - np_args_ref = [np.array(x) for x in np_args] func = tvm.build(s, bufs, target=target) func_ref = tvm.build(s_ref, bufs_ref, target=target) @@ -149,17 +144,19 @@ def test_correctness_layout_rewrite_with_pre_transpose(): ctx_ref = tvm.cpu() args = [tvm.nd.array(x, ctx=ctx) for x in np_args] - args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref] + args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args] ctx.sync() func(*args) func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(np_args, np_args_ref) + np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) + np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy()) + np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) if __name__ == "__main__": test_apply_steps_with_layout_rewrite() - # test_correctness_layout_rewrite_with_placeholder() + test_correctness_layout_rewrite_with_placeholder() test_correctness_layout_rewrite_with_pre_transpose() From f819b6000fad959564677c84466e9d6267d609b5 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 26 Oct 2020 10:13:04 +0800 Subject: [PATCH 07/14] CI Fix --- src/auto_scheduler/compute_dag.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 4e806fbc0c06..417f05eb1d0d 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -796,8 +796,8 @@ std::string GetOrigLayout(std::set* placeholder_axis_names, const t return orig_layout; } -std::string GetNewLayout(const State& state, const int stage_id, - const Stage& stage, const te::Operation& op, const te::Tensor& placeholder, +std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage, + const te::Operation& op, const te::Tensor& placeholder, const std::set& placeholder_axis_names) { std::ostringstream os; Array stage_iters; @@ -911,8 +911,8 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, std::vector origin_axes; ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); - std::string new_layout = GetNewLayout(state, stage_id, stage, op, placeholder, - placeholder_axis_names); + std::string new_layout = + GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names); Array new_shape; std::vector new_axes; ParseKernelLayout(new_layout, &new_shape, &new_axes); @@ -935,9 +935,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, } // Add extra layout transpose stage - const auto& layout_transform_tensor = te::compute(new_shape, - [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes] - (const tvm::runtime::Array& indices) -> tvm::PrimExpr { + const auto& layout_transform_tensor = te::compute( + new_shape, + [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, + &new_axes](const tvm::runtime::Array& indices) -> tvm::PrimExpr { Array access_indices; for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { PrimExpr temp = Integer(0); @@ -949,7 +950,8 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, access_indices.push_back(temp); } return placeholder_op.output(0)(access_indices); - }, "auto_schedule_layout_transpose"); + }, + "auto_schedule_layout_transpose"); 
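+        // The lambda above rebuilds, for each original axis, the flattened
+        // index sum(indices[i] * new_stride[i]) over all new-layout axes
+        // carrying the same axis name, so this stage reads the placeholder
+        // in its original layout while writing the rewritten layout.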
new_op_to_update = layout_transform_tensor->op; // Update the transform steps From 77cea519d9d0d9be8a93beed1202286d89569f5f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 26 Oct 2020 11:02:35 +0800 Subject: [PATCH 08/14] Update --- include/tvm/auto_scheduler/compute_dag.h | 4 +++- include/tvm/auto_scheduler/transform_step.h | 16 +++++++++++++++- src/auto_scheduler/compute_dag.cc | 12 +++++++----- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index ede60846d4d1..5304cf79a727 100755 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -230,6 +230,8 @@ class ComputeDAG : public ObjectRef { * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders` * according to the loop nest derived with `transform_steps`. * \param transform_steps Transform steps of a state. + * \param layout_rewrite Different options in layout rewrite. + * \return The updated ComputeDAG after layout rewrite. */ ComputeDAG RewriteLayout(Array* transform_steps, LayoutRewriteOption layout_rewrite) const; @@ -241,7 +243,7 @@ class ComputeDAG : public ObjectRef { * \param stage_to_axes The map that stores all axes for one stage. * Pass a valid pointer if this information needs to be used outside this function. * \param layout_rewrite Rewrite the layout of placeholders specified by - * attr `layout_free_placeholders` + * attr `layout_free_placeholders`. * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower` * or `tvm.build`. */ diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 3d4494ee1a6a..b1d5a3cf3e9f 100755 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -134,7 +134,7 @@ class IteratorNode : public Object { } static constexpr const char* _type_key = "auto_scheduler.Iterator"; - TVM_DECLARE_BASE_OBJECT_INFO(IteratorNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; /*! @@ -182,6 +182,20 @@ class StepNode : public Object { */ class Step : public ObjectRef { public: + /*! + * \brief CopyOnWrite function for Step. + * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different + * steps. + * \return A base StepNode pointer, need to cast to its real StepNode type before doing any + * modifies. + * \code + * + * SplitStep ref; + * StepNode* mutable_ref = ref.CopyOnWrite(); + * dynamic_cast(mutable_ref)->... 
= ...; + * + * \endcode + */ StepNode* CopyOnWrite(); TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 417f05eb1d0d..edd614ad6b63 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -905,24 +905,28 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, } handled_ops.insert(placeholder_op); + // Process original layout std::set placeholder_axis_names; std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder); Array origin_shape; std::vector origin_axes; ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); + // Process new layout std::string new_layout = GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names); Array new_shape; std::vector new_axes; ParseKernelLayout(new_layout, &new_shape, &new_axes); + // Process op updates te::Operation new_op_to_update; if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) { // Create new placeholder new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) { + // Process index strides std::unordered_map axes_stride; for (const auto& i : origin_axes) { axes_stride[i] = Integer(1); @@ -980,12 +984,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite."; } - Array original_ops = p_dag->ops; - te::Operation new_compute_op, original_compute_op; Array new_body; IndexRewriter index_rewriter(placeholder_op, new_layout); - for (auto& op : original_ops) { + for (const auto& op : p_dag->ops) { if (auto* pop = op.as()) { bool need_update = false; for (auto& t : op->InputTensors()) { @@ -995,7 +997,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, } } if (need_update) { - for (auto& body : pop->body) { + for (const auto& body : pop->body) { new_body.push_back(index_rewriter.Rewrite(body)); } original_compute_op = op; @@ -1008,6 +1010,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, // construct the map from original_op to new_op std::unordered_map updated_ops; + Array original_ops = p_dag->ops; p_dag->ops.clear(); for (size_t i = 0; i < original_ops.size(); ++i) { const auto& original_op = original_ops[i]; @@ -1057,7 +1060,6 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, Array old_tensors = p_dag->tensors; ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite(); - for (size_t i = 0; i < old_tensors.size(); ++i) { const auto& old_tensor = old_tensors[i]; if (layout_rewrite != LayoutRewriteOption::RewriteWithPlaceholder && From 8b2f716b7f866a1746fc2a7f989bd64585619e5b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 27 Oct 2020 14:15:35 +0800 Subject: [PATCH 09/14] Update --- include/tvm/auto_scheduler/compute_dag.h | 16 +++++++------ include/tvm/auto_scheduler/transform_step.h | 2 +- python/tvm/auto_scheduler/compute_dag.py | 12 +++++----- src/auto_scheduler/compute_dag.cc | 12 +++++----- .../test_auto_scheduler_layout_rewrite.py | 24 ++++++++++++------- 5 files changed, 38 insertions(+), 28 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 5304cf79a727..da0d196f4912 100755 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -195,19 +195,21 @@ class ComputeDAGNode : public Object { }; /*! 
- * \brief Several options for applying layout rewrite.
- * This is an optimization to rewrite the shape of input tensors according to the schedule we get.
+ * \brief Options for applying layout rewrite.
+ * This is an optimization to rewrite the layout of input tensors according to the schedule we get.
  */
 enum class LayoutRewriteOption : int {
   /*! \brief Do not process layout rewrite. */
   NoRewrite = 0,
+  /*! \brief Insert layout transformation stages for input placeholders in the compute DAG. */
+  InsertTransformStage = 1,
   /*!
-   * \brief Modify the placeholder to suit the schedule.
-   * \note This should be used along with the graph optimization in Relay.
+   * \brief Do not insert layout transformation stages and assume the input placeholders
+   * are pre-transformed.
+   * \note The lowered function with this option does not accept the original input shapes,
+   * so this option must be used along with a layout conversion pass in Relay.
    */
-  RewriteWithPlaceholder = 1,
-  /*! \brief Insert a pre-transpose stage between the placeholder and the compute op to suit the schedule. */
-  RewriteWithPreTranspose = 2
+  RewriteForPreTransformed = 2,
 };

 /*!

diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index b1d5a3cf3e9f..4cc1551e76fc 100755
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -187,7 +187,7 @@ class Step : public ObjectRef {
    * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
    * steps.
    * \return A base StepNode pointer, need to cast to its real StepNode type before doing any
-   * modifies.
+   * modifications.
    * \code
    *
    * SplitStep ref;

diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 02ac05191585..17600c67fb26 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -50,11 +50,11 @@ class ComputeDAG(Object):
     compute : Union[List[Tensor], str, Schedule]
         Input/output tensors or workload key for a compute declaration.
     """
-    LAYOUT_REWRITE_TABLE = {
-        "NoRewrite": 0,
-        "RewriteWithPlaceholder": 1,
-        "RewriteWithPreTranspose": 2,
-    }
+
+    # Layout Rewrite Options
+    NoRewrite = 0
+    InsertTransformStage = 1
+    RewriteForPreTransformed = 2

     def __init__(self, compute_or_sche):
         if isinstance(compute_or_sche, str):
@@ -86,7 +86,7 @@ def get_init_state(self):
         """
         return State(self.init_state, self)

-    def apply_steps_from_state(self, state, layout_rewrite=LAYOUT_REWRITE_TABLE["NoRewrite"]):
+    def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
         """
         Apply the history transform steps from a State to get a TVM schedule.
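For reference, a minimal sketch of driving the two renamed options through the Python API. `dag` and `state` are assumed to come from an existing search task (as in the tests below), so this is illustrative rather than part of the patch:

    from tvm import auto_scheduler

    # Keep the original input shapes; an explicit transpose stage is inserted:
    sch, bufs = dag.apply_steps_from_state(
        state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
    )

    # Assume the inputs arrive already transformed; this must be paired with a
    # layout conversion pass in Relay:
    sch, bufs = dag.apply_steps_from_state(
        state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
    )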
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index edd614ad6b63..090e6daf9859 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -863,6 +863,8 @@ std::string GetNewLayout(const State& state, const int stage_id, const Stage& st
 ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
                                      LayoutRewriteOption layout_rewrite) const {
+  CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite)
+      << "ComputeDAG::RewriteLayout must not be called with LayoutRewriteOption::NoRewrite.";
   ComputeDAG new_dag = *this;
   ComputeDAGNode* p_dag = new_dag.CopyOnWrite();

@@ -921,11 +923,11 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
     // Process op updates
     te::Operation new_op_to_update;
-    if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) {
+    if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) {
       // Create new placeholder
       new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
                                            placeholder_op.as<te::PlaceholderOpNode>()->dtype);
-    } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
+    } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
       // Process index strides
       std::unordered_map<std::string, PrimExpr> axes_stride;
       for (const auto& i : origin_axes) {
@@ -980,8 +982,6 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
       }
       transform_steps->push_back(FuseStep(stage_id, to_fuse));
       transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
-    } else {
-      LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite.";
     }

   te::Operation new_compute_op, original_compute_op;
@@ -1015,7 +1015,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
   for (size_t i = 0; i < original_ops.size(); ++i) {
     const auto& original_op = original_ops[i];
     if (original_op == placeholder_op) {
-      if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
+      if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
         p_dag->ops.push_back(placeholder_op);
       }
       p_dag->ops.push_back(new_op_to_update);
       updated_ops[placeholder_op] = new_op_to_update;
@@ -1062,7 +1062,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
   ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
   for (size_t i = 0; i < old_tensors.size(); ++i) {
     const auto& old_tensor = old_tensors[i];
-    if (layout_rewrite != LayoutRewriteOption::RewriteWithPlaceholder &&
+    if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed &&
         old_tensor->op->IsInstance<te::PlaceholderOpNode>()) {
       continue;
     }
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index 356743861bef..3967dce7aefa 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -31,19 +31,22 @@ def test_apply_steps_with_layout_rewrite():
     _, bufs = dag.apply_steps_from_state(s)
     assert bufs[1].shape[0] == 512
     assert bufs[1].shape[1] == 512
-    _, bufs = dag.apply_steps_from_state(s,
-        layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"])
+    _, bufs = dag.apply_steps_from_state(
+        s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
+    )
     assert bufs[1].shape[0] == 4
     assert bufs[1].shape[1] == 8
     assert bufs[1].shape[2] == 4
     assert bufs[1].shape[3] == 4
     assert bufs[1].shape[4] == 512
-    _, bufs = dag.apply_steps_from_state(s,
-        layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"])
+    _, bufs = dag.apply_steps_from_state(
+        s,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage + ) assert bufs[1].shape[0] == 512 assert bufs[1].shape[1] == 512 +@tvm.testing.requires_llvm def test_correctness_layout_rewrite_with_placeholder(): N = 128 target = tvm.target.Target("llvm") @@ -64,8 +67,9 @@ def test_correctness_layout_rewrite_with_placeholder(): ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) - s, bufs = dag.apply_steps_from_state(inp.state, - layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"]) + s, bufs = dag.apply_steps_from_state( + inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed + ) s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] np_args_ref = [np.array(x) for x in np_args] @@ -109,8 +113,10 @@ def test_correctness_layout_rewrite_with_placeholder(): np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) + del measure_ctx +@tvm.testing.requires_llvm def test_correctness_layout_rewrite_with_pre_transpose(): N = 128 target = tvm.target.Target("llvm") @@ -131,8 +137,9 @@ def test_correctness_layout_rewrite_with_pre_transpose(): ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) - s, bufs = dag.apply_steps_from_state(inp.state, - layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) + s, bufs = dag.apply_steps_from_state( + inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage + ) s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] @@ -154,6 +161,7 @@ def test_correctness_layout_rewrite_with_pre_transpose(): np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy()) np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) + del measure_ctx if __name__ == "__main__": From e9b8d34d67537189a3eca2cc2eedca6f9652dfbc Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 28 Oct 2020 10:23:47 +0800 Subject: [PATCH 10/14] Re-trigger CI --- tests/python/unittest/test_auto_scheduler_task_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py index 72b998a5a38a..df7ec1e47b6f 100644 --- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py +++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py @@ -20,11 +20,13 @@ import numpy as np +import tvm from tvm import auto_scheduler from test_auto_scheduler_common import matmul_auto_scheduler_test +@tvm.testing.requires_llvm def test_task_scheduler_round_robin(): tasks = [] for n in [2, 4, 8]: @@ -68,6 +70,7 @@ def objective_func(costs): task_scheduler.tune(tune_option, search_policy="sketch.random") +@tvm.testing.requires_llvm def test_task_scheduler_gradient(): tasks = [] for n in [2, 4]: From 6b874b2ec038d79f92ee8fd7582b84f69082be12 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 1 Nov 2020 16:14:09 +0800 Subject: [PATCH 11/14] Update --- .../unittest/test_auto_scheduler_layout_rewrite.py | 11 +++++++---- 
1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 3967dce7aefa..d09c764aafe8 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -22,6 +22,7 @@ import tvm from tvm import topi from tvm import auto_scheduler, te +import random from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test @@ -47,7 +48,8 @@ def test_apply_steps_with_layout_rewrite(): @tvm.testing.requires_llvm -def test_correctness_layout_rewrite_with_placeholder(): +def test_correctness_layout_rewrite_rewrite_for_preTransformed(): + random.seed(0) N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -117,7 +119,8 @@ def test_correctness_layout_rewrite_with_placeholder(): @tvm.testing.requires_llvm -def test_correctness_layout_rewrite_with_pre_transpose(): +def test_correctness_layout_rewrite_insert_transform_stage(): + random.seed(0) N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -166,5 +169,5 @@ def test_correctness_layout_rewrite_with_pre_transpose(): if __name__ == "__main__": test_apply_steps_with_layout_rewrite() - test_correctness_layout_rewrite_with_placeholder() - test_correctness_layout_rewrite_with_pre_transpose() + test_correctness_layout_rewrite_rewrite_for_preTransformed() + test_correctness_layout_rewrite_insert_transform_stage() From fa52e263279a2e135f4ab2825dff9b30dcc4438e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 1 Nov 2020 19:50:26 -0800 Subject: [PATCH 12/14] Update test_auto_scheduler_layout_rewrite.py --- .../unittest/test_auto_scheduler_layout_rewrite.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index d09c764aafe8..e7267c18729a 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -22,7 +22,6 @@ import tvm from tvm import topi from tvm import auto_scheduler, te -import random from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test @@ -49,7 +48,6 @@ def test_apply_steps_with_layout_rewrite(): @tvm.testing.requires_llvm def test_correctness_layout_rewrite_rewrite_for_preTransformed(): - random.seed(0) N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -113,14 +111,13 @@ def test_correctness_layout_rewrite_rewrite_for_preTransformed(): func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) - np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) + tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4) del measure_ctx @tvm.testing.requires_llvm def test_correctness_layout_rewrite_insert_transform_stage(): - random.seed(0) N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -161,9 +158,10 @@ def test_correctness_layout_rewrite_insert_transform_stage(): func_ref(*args_ref) ctx.sync() - 
np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) - np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy()) - np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) + + tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4) del measure_ctx From 10aa3803795da5272148e8c052d53a7966528b62 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 1 Nov 2020 19:52:34 -0800 Subject: [PATCH 13/14] Update test_auto_scheduler_layout_rewrite.py --- tests/python/unittest/test_auto_scheduler_layout_rewrite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index e7267c18729a..4a11d0fb0ca0 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -158,7 +158,6 @@ def test_correctness_layout_rewrite_insert_transform_stage(): func_ref(*args_ref) ctx.sync() - tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4) tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), rtol=1e-4) tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4) From be5c59f0adb7d4358b91decd357f383a4754c90d Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 2 Nov 2020 14:58:35 +0800 Subject: [PATCH 14/14] Update task_scheduler ut, re-trigger CI --- .../unittest/test_auto_scheduler_task_scheduler.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py index 1fb048814ffb..2debc14fc356 100644 --- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py +++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py @@ -22,6 +22,7 @@ import numpy as np import tvm +import tvm.testing from tvm import auto_scheduler from test_auto_scheduler_common import matmul_auto_scheduler_test @@ -41,8 +42,10 @@ def objective_func(costs): num_trials_per_task = 2 # Tune all tasks + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tune_option = auto_scheduler.TuningOptions( num_measure_trials=num_trials_per_task * len(tasks), + runner=measure_ctx.runner, num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -69,13 +72,16 @@ def objective_func(costs): num_measures_per_round=1, ) task_scheduler.tune(tune_option, search_policy="sketch.random") + del measure_ctx +@tvm.testing.requires_llvm def task_scheduler_round_robin_spawn(): assert multiprocessing.get_start_method(False) == "spawn" test_task_scheduler_round_robin() +@tvm.testing.requires_llvm def test_task_scheduler_round_robin_spawn(): ctx = multiprocessing.get_context("spawn") p = ctx.Process(target=task_scheduler_round_robin_spawn) @@ -83,6 +89,7 @@ def test_task_scheduler_round_robin_spawn(): p.join() +@tvm.testing.requires_llvm def test_task_scheduler_gradient(): tasks = [] for n in [2, 4]: @@ -97,8 +104,10 @@ def objective_func(costs): n_trials = 5 # Tune all tasks + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tune_option = auto_scheduler.TuningOptions( num_measure_trials=n_trials, + runner=measure_ctx.runner, num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -120,6 
+129,7 @@ def objective_func(costs): assert counters[tasks[0].workload_key] == n_trials - 1 assert counters[tasks[1].workload_key] == 1 + del measure_ctx if __name__ == "__main__":
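For context, the LocalRPCMeasureContext pattern added in this last patch (create the context before tuning, route measurements through its runner, and delete the context only after tuning completes) looks roughly like this in isolation. This is a sketch: the workload name and log file name are illustrative.

    import tvm
    from tvm import auto_scheduler, te

    @auto_scheduler.register_workload
    def toy_matmul(n):
        # Illustrative workload; any registered compute declaration works.
        A = te.placeholder((n, n), name="A")
        B = te.placeholder((n, n), name="B")
        k = te.reduce_axis((0, n), name="k")
        C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
        return [A, B, C]

    task = auto_scheduler.create_task(toy_matmul, (4,), tvm.target.Target("llvm"))

    measure_ctx = auto_scheduler.LocalRPCMeasureContext()  # starts a local RPC server
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=2,
        runner=measure_ctx.runner,  # measure programs through the RPC runner
        measure_callbacks=[auto_scheduler.RecordToFile("toy_matmul.json")],
    )
    auto_scheduler.auto_schedule(task, tuning_options=tune_option)
    del measure_ctx  # shut the RPC server down only after tuning has finished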