From b67542315adc8e4faf1ea43f3387f1da1d5b14f4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 20 Dec 2021 19:52:04 -0800 Subject: [PATCH 1/2] Add measure callbacks. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/measure_callback.h | 145 ++++++++ include/tvm/meta_schedule/task_scheduler.h | 6 + include/tvm/meta_schedule/tune_context.h | 1 + .../measure_callback/__init__.py | 24 ++ .../measure_callback/add_to_database.py | 30 ++ .../measure_callback/echo_statistics.py | 30 ++ .../measure_callback/measure_callback.py | 104 ++++++ .../measure_callback/remove_build_artifact.py | 30 ++ .../measure_callback/update_cost_model.py | 30 ++ .../measure_callback/add_to_database.cc | 65 ++++ .../measure_callback/echo_statistics.cc | 336 ++++++++++++++++++ .../measure_callback/measure_callback.cc | 50 +++ .../measure_callback/remove_build_artifact.cc | 52 +++ .../measure_callback/update_cost_model.cc | 53 +++ src/meta_schedule/utils.h | 19 + .../test_meta_schedule_measure_callback.py | 132 +++++++ 16 files changed, 1107 insertions(+) create mode 100644 include/tvm/meta_schedule/measure_callback.h create mode 100644 python/tvm/meta_schedule/measure_callback/__init__.py create mode 100644 python/tvm/meta_schedule/measure_callback/add_to_database.py create mode 100644 python/tvm/meta_schedule/measure_callback/echo_statistics.py create mode 100644 python/tvm/meta_schedule/measure_callback/measure_callback.py create mode 100644 python/tvm/meta_schedule/measure_callback/remove_build_artifact.py create mode 100644 python/tvm/meta_schedule/measure_callback/update_cost_model.py create mode 100644 src/meta_schedule/measure_callback/add_to_database.cc create mode 100644 src/meta_schedule/measure_callback/echo_statistics.cc create mode 100644 src/meta_schedule/measure_callback/measure_callback.cc create mode 100644 src/meta_schedule/measure_callback/remove_build_artifact.cc create mode 100644 src/meta_schedule/measure_callback/update_cost_model.cc create mode 100644 tests/python/unittest/test_meta_schedule_measure_callback.py diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h new file mode 100644 index 000000000000..f0763a468cb4 --- /dev/null +++ b/include/tvm/meta_schedule/measure_callback.h @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TaskScheduler; + +/*! \brief Rules to apply after measure results is available. */ +class MeasureCallbackNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MeasureCallbackNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Apply a measure callback rule with given arguments. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builder_results The builder results by building the measure candidates. + * \param runner_results The runner results by running the built measure candidates. + */ + virtual void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const Array& measure_candidates, // + const Array& builder_results, // + const Array& runner_results) = 0; + + static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! \brief The measure callback with customized methods on the python-side. */ +class PyMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! + * \brief Apply a measure callback to the given schedule. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + using FApply = + runtime::TypedPackedFunc& measure_candidates, // + const Array& builds, // + const Array& results)>; + /*! + * \brief Get the measure callback function as string with name. + * \return The string of the measure callback function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) final { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results); + } + + static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode + * \sa MeasureCallbackNode + */ +class MeasureCallback : public runtime::ObjectRef { + public: + /*! + * \brief Create a measure callback that adds the measurement results into the database + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback AddToDatabase(); + /*! + * \brief Create a measure callback that removes the build artifacts from the disk + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback RemoveBuildArtifact(); + /*! + * \brief Create a measure callback that echos the statistics of the tuning process to the console + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback EchoStatistics(); + /*! + * \brief Create a measure callback that updates the cost model with measurement result. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback UpdateCostModel(); + /*! + * \brief Create a measure callback with customized methods on the python-side. + * \param f_apply The packed function of `Apply`. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, + PyMeasureCallbackNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 5841e856e05f..f28c33dc4fe4 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -73,6 +73,10 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The cost model of the scheduler. */ + Optional cost_model; + /*! \brief The list of measure callbacks of the scheduler. */ + Array measure_callbacks; /*! \brief The default desctructor. */ virtual ~TaskSchedulerNode() = default; @@ -82,6 +86,8 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("cost_model", &cost_model); + v->Visit("measure_callbacks", &measure_callbacks); } /*! \brief Auto-tuning. */ diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 559f2da7f9d9..6eacd4d4f12a 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,6 +20,7 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include #include #include #include diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py new file mode 100644 index 000000000000..f697e7733e7e --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.measure_callback package. +""" +from .measure_callback import MeasureCallback, PyMeasureCallback +from .add_to_database import AddToDatabase +from .echo_statistics import EchoStatistics +from .remove_build_artifact import RemoveBuildArtifact +from .update_cost_model import UpdateCostModel diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py new file mode 100644 index 000000000000..ab61e87f647d --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that adds the measurement results into the database""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.AddToDatabase") +class AddToDatabase(MeasureCallback): + def __init__(self) -> None: + """A callback that adds the measurement results into the database""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/measure_callback/echo_statistics.py b/python/tvm/meta_schedule/measure_callback/echo_statistics.py new file mode 100644 index 000000000000..867409f88174 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/echo_statistics.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that echos the statistics of the tuning process to the console""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.EchoStatistics") +class EchoStatistics(MeasureCallback): + def __init__(self) -> None: + """A callback that echos the statistics of the tuning process to the console""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py new file mode 100644 index 000000000000..2b3a36918895 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule MeasureCallback.""" + +from typing import List, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..builder import BuilderResult +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..task_scheduler import TaskScheduler + + +@register_object("meta_schedule.MeasureCallback") +class MeasureCallback(Object): + """Rules to apply after measure results is available.""" + + def apply( + self, + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builder_results: List[BuilderResult], + runner_results: List[RunnerResult], + ) -> None: + """Apply a measure callback to the given schedule. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + task_id: int + The task id. + measure_candidates: List[MeasureCandidate] + The measure candidates. + builder_results: List[BuilderResult] + The builder results by building the measure candidates. + runner_results: List[RunnerResult] + The runner results by running the built measure candidates. + """ + return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member + self, + task_scheduler, + task_id, + measure_candidates, + builder_results, + runner_results, + ) + + +@register_object("meta_schedule.PyMeasureCallback") +class PyMeasureCallback(MeasureCallback): + """An abstract MeasureCallback with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, MeasureCallback) + def f_apply( + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builder_results: List[BuilderResult], + runner_results: List[RunnerResult], + ) -> None: + return self.apply( + task_scheduler, + task_id, + measure_candidates, + builder_results, + runner_results, + ) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py new file mode 100644 index 000000000000..4b2e1ab7f428 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that removes the build artifacts from the disk""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.RemoveBuildArtifact") +class RemoveBuildArtifact(MeasureCallback): + def __init__(self) -> None: + """A callback that removes the build artifacts from the disk""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py new file mode 100644 index 000000000000..c6ee1d26fe6d --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A measure callback that updates the cost model""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.UpdateCostModel") +class UpdateCostModel(MeasureCallback): + def __init__(self) -> None: + """A measure callback that updates the cost model""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackUpdateCostModel, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc new file mode 100644 index 000000000000..b29405333d79 --- /dev/null +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class AddToDatabaseNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + TuneContext task = task_scheduler->tasks[task_id]; + Database database = task_scheduler->database; + Workload workload = database->CommitWorkload(task->mod.value()); + Target target = task->target.value(); + ICHECK_EQ(runner_results.size(), measure_candidates.size()); + int n = runner_results.size(); + for (int i = 0; i < n; ++i) { + RunnerResult result = runner_results[i]; + MeasureCandidate candidate = measure_candidates[i]; + if (result->error_msg.defined()) { + continue; + } + database->CommitTuningRecord(TuningRecord( + /*trace=*/candidate->sch->trace().value(), + /*run_secs=*/result->run_secs.value(), + /*workload=*/workload, + /*target=*/target, + /*args_info=*/candidate->args_info)); + } + } + + static constexpr const char* _type_key = "meta_schedule.AddToDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(AddToDatabaseNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::AddToDatabase() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(AddToDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") + .set_body_typed(MeasureCallback::AddToDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc new file mode 100644 index 000000000000..1209e6cedfb4 --- /dev/null +++ b/src/meta_schedule/measure_callback/echo_statistics.cc @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +double CountFlop(const IRModule& mod) { + struct TResult { + using TTable = std::unordered_map; + + TResult() = default; + + explicit TResult(const tvm::DataType& dtype) { Add(dtype); } + + void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } + + TResult operator+=(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + data_[kv.first] += kv.second; + } + return *this; + } + + TResult operator*=(int64_t rhs) { + for (auto& kv : data_) { + kv.second *= rhs; + } + return *this; + } + + TResult MaxWith(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + double& v = data_[kv.first]; + if (v < kv.second) { + v = kv.second; + } + } + return *this; + } + + struct DType { + uint8_t code : 8; + uint8_t bits : 8; + uint16_t lanes : 16; + }; + static_assert(sizeof(DType) == 4, "Incorrect size of DType"); + + static String Int2Str(int32_t dtype) { + union { + DType dst; + int32_t src; + } converter; + converter.src = dtype; + static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"}; + std::ostringstream os; + os << type_code_tab[converter.dst.code]; + os << static_cast(converter.dst.bits); + if (converter.dst.lanes != 1) { + os << "x" << static_cast(converter.dst.lanes); + } + return os.str(); + } + + static int32_t DataType2Int(const tvm::DataType& dtype) { + union { + DType src; + int32_t dst; + } converter; + converter.src.code = dtype.code(); + converter.src.bits = dtype.bits(); + converter.src.lanes = dtype.lanes(); + return converter.dst; + } + + TTable data_; + }; + + class FlopCounter : public ExprFunctor, + public StmtFunctor { + public: + ~FlopCounter() {} + + TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } + TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } + + TResult VisitStmt_(const IfThenElseNode* branch) override { + TResult cond = VisitExpr(branch->condition); + cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); + return cond; + } + + TResult VisitStmt_(const BufferStoreNode* store) override { + TResult result = VisitExpr(store->value); + for (const PrimExpr& e : store->indices) { + result += VisitExpr(e); + } + return result; + } + + TResult VisitStmt_(const SeqStmtNode* seq) override { + TResult result; + for (const Stmt& stmt : seq->seq) { + result += VisitStmt(stmt); + } + return result; + } + + TResult VisitStmt_(const BlockRealizeNode* block) override { + return VisitStmt(block->block->body); + } + + TResult VisitStmt_(const BlockNode* block) override { + TResult result; + if (block->init.defined()) { + result += VisitStmt(block->init.value()); + } + result += VisitStmt(block->body); + return result; + } + + TResult VisitStmt_(const ForNode* loop) override { + TResult result = VisitStmt(loop->body); + const auto* int_imm = loop->extent.as(); + ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " + << loop->extent->GetTypeKey(); + result *= int_imm->value; + return result; + } + +#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \ + TResult VisitExpr_(const Node* op) final { \ + TResult result(op->dtype); \ + result += VisitExpr(op->a); \ + result += VisitExpr(op->b); \ + return result; \ + } + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode); +#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY + TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } + TResult VisitExpr_(const VarNode* op) override { return TResult(); } + TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } + TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } + TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } + TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } + TResult VisitExpr_(const NotNode* op) override { + TResult result(op->dtype); + result += VisitExpr(op->a); + return result; + } + TResult VisitExpr_(const SelectNode* op) override { + TResult cond = VisitExpr(op->condition); + cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); + return cond; + } + TResult VisitExpr_(const CallNode* op) override { + TResult ret; + for (const auto& x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + }; + FlopCounter counter; + TResult result; + for (const auto& kv : mod->functions) { + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + result += counter.VisitStmt(prim_func->body); + } + } + double cnt = 0.0; + int i32 = TResult::DataType2Int(tvm::DataType::Int(32)); + int i64 = TResult::DataType2Int(tvm::DataType::Int(64)); + int u1 = TResult::DataType2Int(tvm::DataType::UInt(1)); + for (const auto& kv : result.data_) { + if (kv.first != i32 && kv.first != i64 && kv.first != u1) { + cnt += kv.second; + } + } + return cnt; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +constexpr const double kMaxTime = 1e10; + +std::string GetTaskName(const TuneContext& task, int task_id) { + std::ostringstream os; + os << '#' << task_id << ": " << task->task_name; + return os.str(); +} + +double GetRunMs(const Array& run_secs) { + double total = 0.0; + for (const FloatImm& i : run_secs) { + total += i->value; + } + return total * 1e3 / run_secs.size(); +} + +struct TaskInfo { + std::string name; + double flop = 0.0; + int trials = 0; + int best_round = -1; + double best_ms = kMaxTime; + double best_gflops = 0.0; + int error_count = 0; + + explicit TaskInfo(const String& name) : name(name) {} + + void Update(double run_ms) { + ++trials; + if (run_ms < best_ms) { + best_ms = run_ms; + best_round = trials; + best_gflops = flop / run_ms / 1e6; + } + LOG(INFO) << "[" << name << "] Trial #" << trials // + << std::fixed << std::setprecision(4) // + << ": GFLOPs: " << (flop / run_ms / 1e6) // + << ". Time: " << run_ms << " ms" // + << ". Best GFLOPs: " << best_gflops; + } + + void UpdateError(std::string err, const MeasureCandidate& candidate) { + static const auto* f_proc = runtime::Registry::Get("meta_schedule._process_error_message"); + ICHECK(f_proc != nullptr); + err = (*f_proc)(err).operator std::string(); + ++error_count; + ++trials; + LOG(INFO) << "[" << name << "] Trial #" << trials // + << std::fixed << std::setprecision(4) // + << ": Error in building: " << err << "\n" + << tir::AsTVMScript(candidate->sch->mod()) << "\n" + << Concat(candidate->sch->trace().value()->AsPython(false), "\n"); + } +}; + +class EchoStatisticsNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + if (this->task_info.empty()) { + SetupTaskInfo(task_scheduler->tasks); + } + ICHECK_EQ(measure_candidates.size(), builder_results.size()); + ICHECK_EQ(measure_candidates.size(), runner_results.size()); + int n = measure_candidates.size(); + TuneContext task = task_scheduler->tasks[task_id]; + TaskInfo& info = this->task_info[task_id]; + std::string task_name = GetTaskName(task, task_id); + for (int i = 0; i < n; ++i) { + MeasureCandidate candidate = measure_candidates[i]; + BuilderResult builder_result = builder_results[i]; + RunnerResult runner_result = runner_results[i]; + if (Optional err = builder_result->error_msg) { + info.UpdateError(err.value(), candidate); + } else if (Optional err = runner_result->error_msg) { + info.UpdateError(err.value(), candidate); + } else { + ICHECK(runner_result->run_secs.defined()); + info.Update(GetRunMs(runner_result->run_secs.value())); + } + } + } + + void SetupTaskInfo(const Array& tasks) { + task_info.reserve(tasks.size()); + int task_id = 0; + for (const TuneContext& task : tasks) { + task_info.push_back(TaskInfo(GetTaskName(task, task_id))); + TaskInfo& info = task_info.back(); + info.flop = tir::CountFlop(task->mod.value()); + ++task_id; + } + } + + std::vector task_info; + + static constexpr const char* _type_key = "meta_schedule.EchoStatistics"; + TVM_DECLARE_FINAL_OBJECT_INFO(EchoStatisticsNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::EchoStatistics() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(EchoStatisticsNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackEchoStatistics") + .set_body_typed(MeasureCallback::EchoStatistics); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc new file mode 100644 index 000000000000..733d118c735d --- /dev/null +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return MeasureCallback(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") + .set_body_method(&MeasureCallbackNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") + .set_body_typed(MeasureCallback::PyMeasureCallback); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc new file mode 100644 index 000000000000..649636def112 --- /dev/null +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RemoveBuildArtifactNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + static const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir"); + for (const BuilderResult& build_result : builder_results) { + if (Optional path = build_result->artifact_path) { + (*f_rm)(path.value()); + } + } + } + + static constexpr const char* _type_key = "meta_schedule.RemoveBuildArtifact"; + TVM_DECLARE_FINAL_OBJECT_INFO(RemoveBuildArtifactNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::RemoveBuildArtifact() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") + .set_body_typed(MeasureCallback::RemoveBuildArtifact); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc new file mode 100644 index 000000000000..58c86abadfe9 --- /dev/null +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class UpdateCostModelNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + TuneContext task = task_scheduler->tasks[task_id]; + ICHECK(task_scheduler->cost_model.defined()) // + << "Cost model must be defined for the task scheduler!"; + ICHECK(task->measure_candidates.defined()) // + << "Task's measure candidates must be present!"; + CostModel cost_model = task_scheduler->cost_model.value(); + cost_model->Update(task, task->measure_candidates.value(), runner_results); + } + + static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(UpdateCostModelNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::UpdateCostModel() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(UpdateCostModelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") + .set_body_typed(MeasureCallback::UpdateCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index f4f95755408c..0a9ce4a1aed9 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -214,6 +215,24 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Concatenate strings + * \param strs The strings to concatenate + * \param delim The delimiter + * \return The concatenated string + */ +inline std::string Concat(const Array& strs, const std::string& delim) { + if (strs.empty()) { + return ""; + } + std::ostringstream os; + os << strs[0]; + for (int i = 1, n = strs.size(); i < n; ++i) { + os << delim << strs[i]; + } + return os.str(); +} + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py new file mode 100644 index 000000000000..b36d6ca7cfbb --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import pytest +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.measure_callback import PyMeasureCallback +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler +from tvm.meta_schedule.utils import _get_hex_address +from tvm.script import tir as T +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_measure_callback(): + class FancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + assert len(measure_candidates) == 1 + assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + assert ( + len(builds) == 1 + and builds[0].error_msg is None + and builds[0].artifact_path == "test_build" + ) + assert ( + len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + ) + + measure_callback = FancyMeasureCallback() + measure_callback.apply( + TaskScheduler(), + 0, + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_fail(): + class FailingMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + raise ValueError("test") + + measure_callback = FailingMeasureCallback() + with pytest.raises(ValueError, match="test"): + measure_callback.apply( + TaskScheduler(), + 0, + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_as_string(): + class NotSoFancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + pass + + def __str__(self) -> str: + return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" + + measure_callback = NotSoFancyMeasureCallback() + pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)") + assert pattern.match(str(measure_callback)) + + +if __name__ == "__main__": + test_meta_schedule_measure_callback() + test_meta_schedule_measure_callback_fail() + test_meta_schedule_measure_callback_as_string() From 0d645605a4d7caa552a848b8aa59df0b7a4974ac Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 20 Dec 2021 20:16:26 -0800 Subject: [PATCH 2/2] Fix comments. --- include/tvm/meta_schedule/measure_callback.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index f0763a468cb4..e9abb123012a 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -41,7 +41,7 @@ class MeasureCallbackNode : public runtime::Object { /*! * \brief Apply a measure callback rule with given arguments. * \param task_scheduler The task scheduler. - * \param tasks The list of tune context to process. + * \param task_id The id of the task (tune context) to apply measure callbacks. * \param measure_candidates The measure candidates. * \param builder_results The builder results by building the measure candidates. * \param runner_results The runner results by running the built measure candidates. @@ -132,6 +132,7 @@ class MeasureCallback : public runtime::ObjectRef { /*! * \brief Create a measure callback with customized methods on the python-side. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The measure callback created. */ TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,