From a2b4a945e45b6723f5cc55cd5beac96af3b5ea14 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 13 Oct 2021 23:40:53 -0700 Subject: [PATCH] Squashed commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (#485) [Meta Schedule][M3c] PostOrderApply (#486) Fix Post Order Apply (#490) [MetaSchedule] Relay Integration (#489) [M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (#492) Fix replay trace. (#493) [M3c][Meta Schedule] Implement the Replay Func class. (#495) [PR] Test script for meta-schedule task extraction. Interface to load… (#494) [Meta Schedule Refactor] Get child blocks (#500) Read-at && Write-at (#497) [M3c][Meta Schedule] Measure Callbacks (#498) [Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (#496) [MetaSchedule] Sample-Perfect-Tile (#501) Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Wuwei Lin Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> --- include/tvm/meta_schedule/builder.h | 1 + include/tvm/meta_schedule/database.h | 23 +- include/tvm/meta_schedule/integration.h | 214 +++++++++ include/tvm/meta_schedule/measure_callback.h | 126 ++++++ include/tvm/meta_schedule/mutator.h | 125 ++++++ include/tvm/meta_schedule/postproc.h | 130 ++++++ include/tvm/meta_schedule/runner.h | 5 +- include/tvm/meta_schedule/schedule_rule.h | 129 ++++++ include/tvm/meta_schedule/search_strategy.h | 20 +- include/tvm/meta_schedule/space_generator.h | 12 + include/tvm/meta_schedule/task_scheduler.h | 64 ++- include/tvm/meta_schedule/tune_context.h | 19 + include/tvm/tir/schedule/schedule.h | 33 ++ python/tvm/meta_schedule/__init__.py | 4 + python/tvm/meta_schedule/builder/builder.py | 5 +- python/tvm/meta_schedule/database/database.py | 18 +- python/tvm/meta_schedule/integration.py | 238 ++++++++++ .../measure_callback/__init__.py | 20 + .../measure_callback/measure_callback.py | 100 +++++ python/tvm/meta_schedule/mutator/__init__.py | 22 + python/tvm/meta_schedule/mutator/mutator.py | 88 ++++ python/tvm/meta_schedule/postproc/__init__.py | 23 + python/tvm/meta_schedule/postproc/postproc.py | 97 ++++ python/tvm/meta_schedule/runner/runner.py | 5 +- .../meta_schedule/schedule_rule/__init__.py | 19 + .../schedule_rule/schedule_rule.py | 94 ++++ .../meta_schedule/search_strategy/__init__.py | 3 +- .../search_strategy/replay_func.py | 51 +++ .../search_strategy/replay_trace.py | 2 +- .../search_strategy/search_strategy.py | 40 +- .../meta_schedule/space_generator/__init__.py | 2 +- .../space_generator/post_order_apply.py | 36 ++ .../space_generator/space_generator.py | 15 +- .../task_scheduler/round_robin.py | 5 + .../task_scheduler/task_scheduler.py | 140 ++++-- python/tvm/meta_schedule/testing/__init__.py | 19 + .../{testing.py => testing/local_rpc.py} | 2 +- .../meta_schedule/testing/relay_workload.py | 170 +++++++ python/tvm/meta_schedule/tune_context.py | 30 +- python/tvm/meta_schedule/utils.py | 59 +++ python/tvm/te/__init__.py | 2 +- python/tvm/te/operation.py | 18 + python/tvm/tir/schedule/schedule.py | 72 +++ src/meta_schedule/integration.cc | 152 +++++++ .../measure_callback/measure_callback.cc | 50 +++ src/meta_schedule/mutator/mutator.cc | 53 +++ src/meta_schedule/postproc/postproc.cc | 53 +++ .../schedule_rule/schedule_rule.cc | 55 +++ 
.../search_strategy/replay_func.cc | 134 ++++++ .../search_strategy/replay_trace.cc | 17 +- .../space_generator/post_order_apply.cc | 158 +++++++ .../task_scheduler/round_robin.cc | 8 +- .../task_scheduler/task_scheduler.cc | 63 ++- src/meta_schedule/tune_context.cc | 23 +- src/meta_schedule/utils.h | 16 +- src/relay/backend/te_compiler.cc | 24 +- src/relay/backend/te_compiler_cache.cc | 48 +- src/relay/backend/te_compiler_cache.h | 11 +- src/relay/backend/utils.h | 9 + src/te/operation/create_primfunc.cc | 38 +- src/tir/schedule/analysis/analysis.cc | 4 +- src/tir/schedule/concrete_schedule.cc | 52 +++ src/tir/schedule/concrete_schedule.h | 46 +- src/tir/schedule/primitive.h | 41 +- src/tir/schedule/primitive/get_block_loop.cc | 56 +++ src/tir/schedule/primitive/read_write_at.cc | 425 ++++++++++++++++++ src/tir/schedule/primitive/sampling.cc | 316 ++++++++++++- src/tir/schedule/schedule.cc | 18 + src/tir/schedule/traced_schedule.cc | 63 +++ src/tir/schedule/traced_schedule.h | 18 +- src/tir/schedule/utils.h | 42 +- src/tir/transforms/compact_buffer_region.cc | 2 +- .../test_meta_schedule_measure_callback.py | 135 ++++++ .../unittest/test_meta_schedule_mutator.py | 89 ++++ .../test_meta_schedule_post_order_apply.py | 340 ++++++++++++++ .../unittest/test_meta_schedule_postproc.py | 119 +++++ .../test_meta_schedule_schedule_rule.py | 105 +++++ .../test_meta_schedule_search_strategy.py | 54 ++- .../test_meta_schedule_space_generator.py | 8 + .../test_meta_schedule_task_extraction.py | 84 ++++ .../test_meta_schedule_task_scheduler.py | 79 +++- .../test_tir_schedule_read_write_at.py | 221 +++++++++ .../unittest/test_tir_schedule_sampling.py | 43 +- .../unittest/test_tir_schedule_utilities.py | 17 + 84 files changed, 5343 insertions(+), 246 deletions(-) create mode 100644 include/tvm/meta_schedule/integration.h create mode 100644 include/tvm/meta_schedule/measure_callback.h create mode 100644 include/tvm/meta_schedule/mutator.h create mode 100644 include/tvm/meta_schedule/postproc.h create mode 100644 include/tvm/meta_schedule/schedule_rule.h create mode 100644 python/tvm/meta_schedule/integration.py create mode 100644 python/tvm/meta_schedule/measure_callback/__init__.py create mode 100644 python/tvm/meta_schedule/measure_callback/measure_callback.py create mode 100644 python/tvm/meta_schedule/mutator/__init__.py create mode 100644 python/tvm/meta_schedule/mutator/mutator.py create mode 100644 python/tvm/meta_schedule/postproc/__init__.py create mode 100644 python/tvm/meta_schedule/postproc/postproc.py create mode 100644 python/tvm/meta_schedule/schedule_rule/__init__.py create mode 100644 python/tvm/meta_schedule/schedule_rule/schedule_rule.py create mode 100644 python/tvm/meta_schedule/search_strategy/replay_func.py create mode 100644 python/tvm/meta_schedule/space_generator/post_order_apply.py create mode 100644 python/tvm/meta_schedule/testing/__init__.py rename python/tvm/meta_schedule/{testing.py => testing/local_rpc.py} (97%) create mode 100644 python/tvm/meta_schedule/testing/relay_workload.py create mode 100644 src/meta_schedule/integration.cc create mode 100644 src/meta_schedule/measure_callback/measure_callback.cc create mode 100644 src/meta_schedule/mutator/mutator.cc create mode 100644 src/meta_schedule/postproc/postproc.cc create mode 100644 src/meta_schedule/schedule_rule/schedule_rule.cc create mode 100644 src/meta_schedule/search_strategy/replay_func.cc create mode 100644 src/meta_schedule/space_generator/post_order_apply.cc create mode 100644 
src/tir/schedule/primitive/read_write_at.cc create mode 100644 tests/python/unittest/test_meta_schedule_measure_callback.py create mode 100644 tests/python/unittest/test_meta_schedule_mutator.py create mode 100644 tests/python/unittest/test_meta_schedule_post_order_apply.py create mode 100644 tests/python/unittest/test_meta_schedule_postproc.py create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule.py create mode 100644 tests/python/unittest/test_meta_schedule_task_extraction.py create mode 100644 tests/python/unittest/test_tir_schedule_read_write_at.py diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 19358552df..b809843f41 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode { } Array Build(const Array& build_inputs) final { + ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 7ba3c207e3..60c6898f00 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode { // `f_size` is not visited } - static constexpr const char* _type_key = "meta_schedule.PyDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); - - Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); } + Workload CommitWorkload(const IRModule& mod) final { + ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; + return f_commit_workload(mod); + } - void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); } + void CommitTuningRecord(const TuningRecord& record) final { + ICHECK(f_commit_tuning_record != nullptr) + << "PyDatabase's CommitTuningRecord method not implemented!"; + f_commit_tuning_record(record); + } Array GetTopK(const Workload& workload, int top_k) final { + ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } - int64_t Size() final { return f_size(); } + int64_t Size() final { + ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; + return f_size(); + } + + static constexpr const char* _type_key = "meta_schedule.PyDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); }; /*! diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h new file mode 100644 index 0000000000..5f45a2e50b --- /dev/null +++ b/include/tvm/meta_schedule/integration.h @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_INTEGRATION_H_ +#define TVM_META_SCHEDULE_INTEGRATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace meta_schedule { + +/**************** ExtractedTask ****************/ + +/*! + * \brief A tuning task extracted from the high-level IR + */ +class ExtractedTaskNode : public runtime::Object { + public: + /*! \brief The name of the task extracted */ + String task_name; + /*! \brief The high-level IR */ + IRModule mod; + /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ + Array dispatched; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("task_name", &task_name); + v->Visit("mod", &mod); + v->Visit("dispatched", &dispatched); + } + + static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); +}; + +/*! + * \brief Managed reference to ExtractedTaskNode + * \sa ExtractedTaskNode + */ +class ExtractedTask : public runtime::ObjectRef { + public: + /*! + * \brief Constructor. The name of the task extracted + * \brief The high-level IR + * \brief A list of low-level IRs that the high-level IR could potentially dispatch to + */ + explicit ExtractedTask(String task_name, IRModule mod, Array dispatched); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode); +}; + +/**************** IntegrationContext ****************/ + +/*! + * \brief A context manager interface for the integration + */ +class IntegrationContextNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~IntegrationContextNode() = default; + /*! + * \brief The entry point of the integration + * \param task_name The name of the task + * \param mod The high-level IR + * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to. + * NullOpt means the dispatch needs to be done in the context. + * \return There are different types of the output + * 1) NullOpt if there is no feedback hint + * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc + * 3) relay::Function if `mod` should be dispatched to BYOC workflow + * 4) IRModule for unified dispatch + */ + virtual Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) = 0; + + static constexpr const char* _type_key = "meta_schedule.IntegrationContext"; + TVM_DECLARE_BASE_OBJECT_INFO(IntegrationContextNode, runtime::Object); +}; + +/*! + * \brief Managed reference to IntegrationContextNode + * \sa IntegrationContextNode + */ +class IntegrationContext : public runtime::ObjectRef { + friend class IntegrationContextInternal; + friend class With; + + public: + /*! \brief Default destructor */ + virtual ~IntegrationContext() = default; + /*! + * \brief The context manager in the current scope + * \return The IntegrationContext in the current scope. NullOpt if it's currently not under any + * IntegrationContext. + */ + static Optional Current(); + /*! + * \brief The entry point of the integration workflow. 
The compilation process of the high-level + * IR should call this method for task extraction and for feedback hints + * \param task_name The name of the task + * \param mod The high-level IR + * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to + * \return There are different types of the output + * 1) NullOpt if there is no feedback hint + * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc + * 3) relay::Function if `mod` should be dispatched to BYOC workflow + * 4) IRModule for unified dispatch + */ + static Optional EntryPoint(runtime::String task_name, IRModule mod, + Optional> dispatched); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IntegrationContext, runtime::ObjectRef, + IntegrationContextNode); + + protected: + /*! \brief Default constructor */ + IntegrationContext() = default; + /*! \brief Entering the scope of the context manager */ + void EnterWithScope(); + /*! \brief Exiting the scope of the context manager */ + void ExitWithScope(); +}; + +/**************** TaskExtraction ****************/ + +/*! + * \brief An integration context for task extraction + */ +class TaskExtractionNode : public IntegrationContextNode { + public: + /*! \brief The extracted tasks */ + Array tasks{nullptr}; + + void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); } + + // Inherited from base class + Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) final; + + static constexpr const char* _type_key = "meta_schedule.TaskExtraction"; + TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, IntegrationContextNode); +}; + +/*! + * \brief Managed reference to TaskExtractionNode + * \sa TaskExtractionNode + */ +class TaskExtraction : public IntegrationContext { + public: + /*! \brief The path to a cache file storing extracted tasks */ + TaskExtraction(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, IntegrationContext, + TaskExtractionNode); +}; + +/**************** ApplyHistoryBest ****************/ + +/*! + * \brief An integration context that allows application of historically best records from a + * database + */ +class ApplyHistoryBestNode : public IntegrationContextNode { + public: + /*! \brief The database to be queried from */ + Database database{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("database", &database); // + } + + // Inherited from base class + Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) final; + + static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; + TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, IntegrationContextNode); +}; + +/*! + * \brief Managed reference to ApplyHistoryBestNode + * \sa ApplyHistoryBestNode + */ +class ApplyHistoryBest : public IntegrationContext { + public: + /*! + * \brief Constructor + * \param database The database to be queried from + */ + explicit ApplyHistoryBest(Database database); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, IntegrationContext, + ApplyHistoryBestNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_INTEGRATION_H_ diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h new file mode 100644 index 0000000000..9ee7039959 --- /dev/null +++ b/include/tvm/meta_schedule/measure_callback.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TaskScheduler; + +/*! \brief Rules to apply after measure results is available. */ +class MeasureCallbackNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MeasureCallbackNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Apply a measure callback rule with given arguments. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + virtual bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! \brief The measure callback with customized methods on the python-side. */ +class PyMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! + * \brief Apply a measure callback to the given schedule. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + using FApply = + runtime::TypedPackedFunc tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results)>; + /*! + * \brief Get the measure callback function as string with name. + * \return The string of the measure callback function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. 
*/ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_apply` is not visited + // `f_as_string` is not visited + } + + bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) final { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results); + } + + static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode + * \sa MeasureCallbackNode + */ +class MeasureCallback : public runtime::ObjectRef { + public: + /*! + * \brief Create a measure callback with customized methods on the python-side. + * \param f_apply The packed function of `Apply`. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h new file mode 100644 index 0000000000..82f5b76834 --- /dev/null +++ b/include/tvm/meta_schedule/mutator.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MUTATOR_H_ +#define TVM_META_SCHEDULE_MUTATOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Mutator is designed to mutate the trace to explore the design space. */ +class MutatorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MutatorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + virtual Optional Apply(const tir::Trace& trace) = 0; + + static constexpr const char* _type_key = "meta_schedule.Mutator"; + TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); +}; + +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + /*! 
+ * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + using FApply = runtime::TypedPackedFunc(const tir::Trace&)>; + /*! + * \brief Get the mutator as string with name. + * \return The string of the mutator. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Optional Apply(const tir::Trace& trace) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace); + } + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { + public: + /*! + * \brief Create a mutator with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The mutator created. + */ + TVM_DLL static Mutator PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MUTATOR_H_ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h new file mode 100644 index 0000000000..c24861d697 --- /dev/null +++ b/include/tvm/meta_schedule/postproc.h @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_META_SCHEDULE_POSTPROC_H_ +#define TVM_META_SCHEDULE_POSTPROC_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! + * \brief Rules to apply a post processing to a schedule. + * \note Post processing is designed to deal with the problem of undertermined schedule validity + * after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the + * maximum number below 1024, X is only decided at runtime. + */ +class PostprocNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~PostprocNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a post processing to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the post processing was successfully applied. + */ + virtual bool Apply(const tir::Schedule& schedule) = 0; + + static constexpr const char* _type_key = "meta_schedule.Postproc"; + TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); +}; + +/*! \brief The post processing with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply a post processing to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the post processing was successfully applied. + */ + using FApply = runtime::TypedPackedFunc; + /*! + * \brief Get the post processing function as string with name. + * \return The string of the post processing function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { + public: + /*! + * \brief Create a post processing with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The post processing created. 
+ */ + TVM_DLL static Postproc PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_POSTPROC_H_ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c1451ae977..b154195f43 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode { // `f_run` is not visited } - Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + Array Run(Array runner_inputs) final { + ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; + return f_run(runner_inputs); + } static constexpr const char* _type_key = "meta_schedule.PyRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h new file mode 100644 index 0000000000..92aa46beea --- /dev/null +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Rules to modify a block in a schedule. */ +class ScheduleRuleNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~ScheduleRuleNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a schedule rule to the specific block in the given schedule. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + virtual runtime::Array Apply(const tir::Schedule& sch, + const tir::BlockRV& block) = 0; + + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); +}; + +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! 
+ * \brief The function type of `Apply` method. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + using FApply = + runtime::TypedPackedFunc(const tir::Schedule&, const tir::BlockRV&)>; + /*! + * \brief Get the schedule rule as string with name. + * \return The string of the schedule rule. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + return this->f_apply(sch, block); + } + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { + public: + /*! + * \brief Create a schedule rule with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The schedule rule created. 
+ */ + TVM_DLL static ScheduleRule PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 941dae4336..3a0fa0ab4a 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { @@ -187,20 +188,30 @@ class PySearchStrategyNode : public SearchStrategyNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } void PreTuning(const Array& design_spaces) final { + ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; this->f_pre_tuning(design_spaces); } - void PostTuning() final { this->f_post_tuning(); } + void PostTuning() final { + ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; + this->f_post_tuning(); + } Optional> GenerateMeasureCandidates() final { + ICHECK(f_generate_measure_candidates != nullptr) + << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return this->f_generate_measure_candidates(); } void NotifyRunnerResults(const Array& results) final { + ICHECK(f_notify_runner_results != nullptr) + << "PySearchStrategy's NotifyRunnerResults method not implemented!"; this->f_notify_runner_results(results); } @@ -237,6 +248,13 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + /*! + * \brief Constructor of replay func search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for func replaying. + */ + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 3dc181e05d..a0dfede820 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { } void InitializeWithTuneContext(const TuneContext& tune_context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySpaceGenerator's InitializeWithTuneContext !"; f_initialize_with_tune_context(tune_context); } Array GenerateDesignSpace(const IRModule& mod) final { + ICHECK(f_generate_design_space != nullptr) + << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); } @@ -149,6 +153,14 @@ class SpaceGenerator : public ObjectRef { * \return The design space generator created. */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + /*! + * \brief Create a design space generator that generates design spaces by applying schedule rules + * to blocks in post-DFS order. 
+ * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. + * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PostOrderApply(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index a2db24e31a..062f493ec3 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -73,6 +74,8 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The list of measure callbacks of the scheduler. */ + Array measure_callbacks; /*! \brief The default desctructor. */ virtual ~TaskSchedulerNode() = default; @@ -82,11 +85,18 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("measure_callbacks", &measure_callbacks); } /*! \brief Auto-tuning. */ virtual void Tune(); + /*! + * \brief Initialize modules of the given task. + * \param task_id The task id to be initialized. + */ + virtual void InitializeTask(int task_id); + /*! * \brief Set specific task to be stopped. * \param task_id The task id to be stopped. @@ -116,12 +126,17 @@ class TaskSchedulerNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); }; +class TaskScheduler; + /*! \brief The task scheduler with customized methods on the python-side. */ class PyTaskSchedulerNode : public TaskSchedulerNode { public: /*! \brief The function type of `Tune` method. */ using FTune = runtime::TypedPackedFunc; + /*! \brief The function type of `InitializeTask` method. */ + using FInitializeTask = runtime::TypedPackedFunc; + /*! * \brief The function type of `SetTaskStopped` method. * \param task_id The task id to be stopped. @@ -149,6 +164,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { /*! \brief The packed function to the `Tune` funcion. */ FTune f_tune; + /*! \brief The packed function to the `InitializeTask` funcion. */ + FInitializeTask f_initialize_task; /*! \brief The packed function to the `SetTaskStopped` function. */ FSetTaskStopped f_set_task_stopped; /*! \brief The packed function to the `IsTaskRunning` function. 
*/ @@ -160,6 +177,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { void VisitAttrs(tvm::AttrVisitor* v) { // `f_tune` is not visited + // `f_initialize_task` is not visited // `f_set_task_stopped` is not visited // `f_is_task_running` is not visited // `f_join_running_task` is not visited @@ -167,22 +185,47 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { } void Tune() final { // - f_tune(); + if (f_tune == nullptr) { + TaskSchedulerNode::Tune(); + } else { + f_tune(); + } + } + + void InitializeTask(int task_id) final { // + if (f_initialize_task == nullptr) { + TaskSchedulerNode::InitializeTask(task_id); + } else { + f_initialize_task(task_id); + } } void SetTaskStopped(int task_id) final { // - f_set_task_stopped(task_id); + if (f_set_task_stopped == nullptr) { + TaskSchedulerNode::SetTaskStopped(task_id); + } else { + f_set_task_stopped(task_id); + } } bool IsTaskRunning(int task_id) final { // - return f_is_task_running(task_id); + if (f_is_task_running == nullptr) { + return TaskSchedulerNode::IsTaskRunning(task_id); + } else { + return f_is_task_running(task_id); + } } void JoinRunningTask(int task_id) final { // - f_join_running_task(task_id); + if (f_join_running_task == nullptr) { + return TaskSchedulerNode::JoinRunningTask(task_id); + } else { + return f_join_running_task(task_id); + } } int NextTaskId() final { // + ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; return f_next_task_id(); } @@ -203,10 +246,19 @@ class TaskScheduler : public runtime::ObjectRef { * \param runner The runner of the scheduler. * \param database The database of the scheduler. */ - TVM_DLL static TaskScheduler RoundRobin(Array tasks, Builder builder, Runner runner, - Database database); + TVM_DLL static TaskScheduler RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // + PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index db72328c91..8ad6aa1d21 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,6 +20,10 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include +#include +#include +#include #include #include #include @@ -38,6 +42,12 @@ class TuneContextNode : public runtime::Object { Optional space_generator; /*! \brief The search strategy. */ Optional search_strategy; + /*! \brief The schedule rules. */ + Array sch_rules; + /*! \brief The post processings. */ + Array postprocs; + /*! \brief The mutators. */ + Array mutators; /*! \brief The name of the tuning task. */ Optional task_name; /*! \brief The random state. 
*/ @@ -57,6 +67,9 @@ class TuneContextNode : public runtime::Object { v->Visit("target", &target); v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); + v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutators", &mutators); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); @@ -81,6 +94,9 @@ class TuneContext : public runtime::ObjectRef { * \param target The target to be tuned for. * \param space_generator The design space generator. * \param search_strategy The search strategy. + * \param sch_rules The schedule rules. + * \param postprocs The post processings. + * \param mutators The mutators. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -89,6 +105,9 @@ class TuneContext : public runtime::ObjectRef { Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c4aa1c953a..fa4be5dee4 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Check the existance of a specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return Whether the corresponding block exists + */ + virtual bool HasBlock(const BlockRV& block_rv) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -194,6 +200,16 @@ class ScheduleNode : public runtime::Object { */ virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) = 0; + /*! + * \brief Sample the factors to perfect tile a specific loop + * \param loop_rv The loop to be tiled + * \param n The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ + virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -210,6 +226,18 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of a specific scope + * \param block_rv The block where the scope is rooted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of under a specific loop + * \param loop_rv The loop under which collecting is conducted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Fuse a list of consecutive loops into one. 
It requires: @@ -305,6 +333,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 2e280ef20a..c57355a391 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,6 +19,10 @@ from . import database from . import builder from . import runner +from . import mutator +from . import postproc +from . import schedule_rule from . import space_generator from . import search_strategy +from . import integration from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index ed81f4c0d3..381051e85f 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -23,6 +23,7 @@ from tvm.target import Target from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.BuilderInput") @@ -119,6 +120,7 @@ class PyBuilder(Builder): def __init__(self): """Constructor.""" + @check_override(self.__class__, Builder) def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: return self.build(build_inputs) @@ -126,6 +128,3 @@ def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: _ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member f_build, ) - - def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 3d05441fe2..fd746e640c 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -25,7 +25,7 @@ from .. 
import _ffi_api from ..arg_info import ArgInfo -from ..utils import _json_de_tvm +from ..utils import _json_de_tvm, check_override @register_object("meta_schedule.Workload") @@ -207,15 +207,19 @@ class PyDatabase(Database): def __init__(self): """Constructor.""" + @check_override(self.__class__, Database) def f_commit_workload(mod: IRModule) -> Workload: return self.commit_workload(mod) + @check_override(self.__class__, Database) def f_commit_tuning_record(record: TuningRecord) -> None: self.commit_tuning_record(record) + @check_override(self.__class__, Database) def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]: return self.get_top_k(workload, top_k) + @check_override(self.__class__, Database, func_name="__len__") def f_size() -> int: return len(self) @@ -226,15 +230,3 @@ def f_size() -> int: f_get_top_k, f_size, ) - - def commit_workload(self, mod: IRModule) -> Workload: - raise NotImplementedError - - def commit_tuning_record(self, record: TuningRecord) -> None: - raise NotImplementedError - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py new file mode 100644 index 0000000000..3d99578a24 --- /dev/null +++ b/python/tvm/meta_schedule/integration.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional, Union + +from tvm._ffi import register_func, register_object +from tvm.ir import IRModule, transform +from tvm.relay import Function as RelayFunc, vm +from tvm.runtime import NDArray, Object +from tvm.target import Target +from tvm.te import Tensor +from tvm.tir import PrimFunc + +from . 
import _ffi_api + + +@register_object("meta_schedule.ExtractedTask") +class ExtractedTask(Object): + """A tuning task extracted from the high-level IR + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : List[IRModule] + A list of low-level IRs that the high-level IR could potentially dispatch to + """ + + task_name: str + mod: IRModule + dispatched: List[IRModule] + + def __init__( + self, + task_name: str, + mod: IRModule, + dispatched: List[IRModule], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + +@register_object("meta_schedule.IntegrationContext") +class IntegrationContext(Object): + """A context manager interface for the integration""" + + def query( + self, + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.IntegrationContextQuery( # type: ignore # pylint: disable=no-member + self, + task_name, + mod, + dispatched, + ) + + @staticmethod + def current() -> Optional["IntegrationContext"]: + """The context manager in the current scope + + Returns + ------- + ctx : Optional[IntegrationContext] + The IntegrationContext in the current scope. + NullOpt if it's currently not under any IntegrationContext. + """ + return _ffi_api.IntegrationContextCurrent() # type: ignore # pylint: disable=no-member + + @staticmethod + def entry_point( + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration workflow. 
The compilation process of the high-level + IR should call this method for task extraction and for feedback hints + + Parameters + ---------- + task_name : str + The name of the task + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.IntegrationContextEntryPoint( # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + def __enter__(self) -> "IntegrationContext": + """Entering the scope of the context manager""" + _ffi_api.IntegrationContextEnterScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.IntegrationContextExitScope(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TaskExtraction") +class TaskExtraction(IntegrationContext): + """An integration context for task extraction + + Parameters + ---------- + tasks : List[ExtractedTask] + The extracted tasks + """ + + tasks: List[ExtractedTask] + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.ApplyHistoryBest") +class ApplyHistoryBest(IntegrationContext): + pass + + +def extract_task( + mod: Union[IRModule, RelayFunc], + params: Optional[Dict[str, NDArray]], + target: Target, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relay program. + + Parameters + ---------- + mod : tvm.IRModule or relay.Function + The module or function to tune + params : dict of str to numpy array + The associated parameters of the program + target : Union[tvm.target.Target, str] + The compilation target + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + + @contextmanager + def _autotvm_silencer(): + from tvm import autotvm + + silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + try: + yield + finally: + autotvm.GLOBAL_SCOPE.silent = silent + + def _thread_run(func: Callable[[], None]) -> None: + import threading + + thread = threading.Thread(target=func) + thread.start() + thread.join() + + env = TaskExtraction() + if isinstance(mod, RelayFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + + def _func(): + with env, _autotvm_silencer(), transform.PassContext( + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.disable_compile_engine_cache": True, + }, + disabled_pass={}, + opt_level=3, + ): + compiler = vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target) + + _thread_run(_func) + return env.tasks diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py new file mode 100644 index 0000000000..f455c1f4c7 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.measure_callback package. +""" +from .measure_callback import MeasureCallback, PyMeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py new file mode 100644 index 0000000000..f7daed55f6 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule MeasureCallback.""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..builder import BuilderResult +from ..runner import RunnerResult +from ..utils import _get_hex_address, check_override + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..task_scheduler import TaskScheduler + + +@register_object("meta_schedule.MeasureCallback") +class MeasureCallback(Object): + """Rules to apply after measure results is available.""" + + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + """Apply a measure callback to the given schedule. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + tasks: List[TuneContext] + The list of tune context to process. + measure_candidats: List[MeasureCandidate] + The measure candidates. + builds: List[BuilderResult] + The builder results by building the measure candidates. + results: List[RunnerResult] + The runner results by running the built measure candidates. + + Returns + ------- + result : bool + Whether the measure callback was successfully applied. 
+ """ + return _ffi_api.MeasureCallbackApply( + self, task_scheduler, tasks, measure_candidates, builds, results + ) + + +@register_object("meta_schedule.PyMeasureCallback") +class PyMeasureCallback(MeasureCallback): + """An abstract MeasureCallback with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, MeasureCallback) + def f_apply( + task_scheduler: "TaskScheduler", + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return self.apply(task_scheduler, tasks, measure_candidates, builds, results) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py new file mode 100644 index 0000000000..f88043b4b4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.mutator package. +Meta Schedule mutator that mutates the trace to explore the +design space. +""" +from .mutator import Mutator, PyMutator diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py new file mode 100644 index 0000000000..f583154fec --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Mutator.""" +from typing import Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Trace + +from ..utils import _get_hex_address, check_override +from .. 
import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class Mutator(Object): + """Mutator is designed to mutate the trace to explore the design space.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the mutator with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the mutator. + """ + _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, trace: Trace) -> Optional[Trace]: + """Apply the mutator function to the given trace. + + Parameters + ---------- + trace : Trace + The given trace for mutation. + + Returns + ------- + trace : Optional[Trace] + None if mutator failed, otherwise return the mutated trace. + """ + return _ffi_api.MutatorApply(self, trace) + + +@register_object("meta_schedule.PyMutator") +class PyMutator(Mutator): + """An abstract mutator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Mutator) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Mutator) + def f_apply(trace: Trace) -> Optional[Trace]: + return self.apply(trace) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMutator({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py new file mode 100644 index 0000000000..5316eb4663 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.postproc package. +Meta Schedule post processings that deal with the problem of +undertermined schedule validity after applying some schedule +primitves at runtime. +""" +from .postproc import Postproc, PyPostproc diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py new file mode 100644 index 0000000000..06e0da8fd3 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
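[Editor's note] A corresponding sketch for the mutator interface above. The do-nothing mutation is deliberately trivial; it only demonstrates the two override points and is not a mutator provided by this series.

# Sketch: a mutator that returns the trace unchanged (a no-op mutation).
from typing import Optional
from tvm.meta_schedule.mutator import PyMutator
from tvm.tir.schedule import Trace


class IdentityMutator(PyMutator):
    """Trivial mutator illustrating the overridable hooks."""

    def initialize_with_tune_context(self, tune_context) -> None:
        pass  # nothing to pre-compute for this toy mutator

    def apply(self, trace: Trace) -> Optional[Trace]:
        return trace  # returning None would signal a failed mutation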
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Postproc.""" + +from typing import TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.Postproc") +class Postproc(Object): + """Rules to apply a post processing to a schedule. + + Note + ---- + Post processing is designed to deal with the problem of undertermined schedule validity after + applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the maximum + number below 1024, X is only decided at runtime. + """ + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the post processing with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the post processing. + """ + _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule) -> bool: + """Apply a post processing to the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be post processed. + + Returns + ------- + result : bool + Whether the post processing was successfully applied. + """ + return _ffi_api.PostprocApply(self, sch) + + +@register_object("meta_schedule.PyPostproc") +class PyPostproc(Postproc): + """An abstract Postproc with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Postproc) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Postproc) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyPostproc({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 9f7be8ea4a..71a557dca3 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -22,6 +22,7 @@ from .. 
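[Editor's note] In the same spirit, a sketch of a user-defined postprocessor. Accepting every schedule unchanged is a placeholder; the built-in postprocessors land in later patches of this series.

# Sketch: a postprocessor that accepts every schedule unchanged.
from tvm.meta_schedule.postproc import PyPostproc
from tvm.tir.schedule import Schedule


class AcceptAll(PyPostproc):
    """Toy postprocessor illustrating the Postproc override points."""

    def initialize_with_tune_context(self, tune_context) -> None:
        pass  # no per-task state needed

    def apply(self, sch: Schedule) -> bool:
        # A real postprocessor would rewrite `sch` (e.g. fuse or parallelize loops)
        # and return False if the schedule turned out to be invalid.
        return True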
import _ffi_api from ..arg_info import ArgInfo +from ..utils import check_override @register_object("meta_schedule.RunnerInput") @@ -158,6 +159,7 @@ class PyRunner(Runner): def __init__(self) -> None: """Constructor""" + @check_override(self.__class__, Runner) def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: return self.run(runner_inputs) @@ -165,6 +167,3 @@ def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member f_run, ) - - def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py new file mode 100644 index 0000000000..34a7590b60 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -0,0 +1,19 @@ +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.schedule_rule package. +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from .schedule_rule import ScheduleRule, PyScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py new file mode 100644 index 0000000000..ec101410f6 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule, BlockRV + +from ..utils import _get_hex_address, check_override +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.ScheduleRule") +class ScheduleRule(Object): + """Rules to modify a block in a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the schedule rule with a tune context. 
+ + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the schedule rule. + """ + _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + """Apply a schedule rule to the specific block in the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be modified. + block : BlockRV + The specific block to apply the schedule rule. + + Returns + ------- + design_spaces : List[Schedule] + The list of schedules generated by applying the schedule rule. + """ + return _ffi_api.ScheduleRuleApply(self, schedule, block) + + +@register_object("meta_schedule.PyScheduleRule") +class PyScheduleRule(ScheduleRule): + """An abstract schedule rule with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, ScheduleRule) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, ScheduleRule) + def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(sch, block) + + def f_as_string() -> str: + return self.__str__() + + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyScheduleRule({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 609baa2677..e306c307bc 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -20,5 +20,6 @@ to generate measure candidates. """ -from .search_strategy import SearchStrategy, PySearchStrategy +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace +from .replay_func import ReplayFunc diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 0000000000..8edd74ab02 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. 
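[Editor's note] A sketch of the subclassing pattern for the schedule-rule interface above. Returning the input schedule untouched is only a stand-in for a real transformation such as tiling or inlining.

# Sketch: a schedule rule that proposes exactly one candidate, the input itself.
from typing import List
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.tir.schedule import BlockRV, Schedule


class KeepAsIs(PyScheduleRule):
    """Toy rule showing how a space generator drives ScheduleRule.apply per block."""

    def initialize_with_tune_context(self, tune_context) -> None:
        pass

    def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
        # A real rule would fork `sch` (e.g. multi-level tiling) and return
        # every variant it generated for this block.
        return [sch]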
import _ffi_api + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transform the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 15f8295f25..3fd8bf7a44 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -41,7 +41,7 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index d270ea61f6..e92bbbefca 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -14,17 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Search Strategy""" - +""" +Meta Schedule search strategy that generates the measure +candidates for measurement. +""" from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object from tvm.runtime import Object -from tvm.tir.schedule import Schedule +from tvm.tir.schedule import Schedule, Trace from .. import _ffi_api from ..arg_info import ArgInfo from ..runner import RunnerResult +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -45,7 +48,11 @@ class MeasureCandidate(Object): sch: Schedule args_info: List[ArgInfo] - def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + def __init__( + self, + sch: Schedule, + args_info: List[ArgInfo], + ) -> None: """Constructor. Parameters @@ -69,10 +76,7 @@ class SearchStrategy(Object): before usage and post-tuned after usage. """ - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the search strategy with tuning context. 
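[Editor's note] For context, a sketch of how the two replay strategies added above are constructed. The trial counts are arbitrary, and the strategy still has to be attached to a TuneContext before it can generate candidates.

# Sketch: constructing the two replay-based search strategies.
from tvm.meta_schedule.search_strategy import ReplayFunc, ReplayTrace

# Re-runs the design-space generator on every trial.
strategy_func = ReplayFunc(num_trials_per_iter=7, num_trials_total=20)

# Replays recorded traces with freshly sampled random decisions instead.
strategy_trace = ReplayTrace(num_trials_per_iter=7, num_trials_total=20)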
Parameters @@ -126,18 +130,23 @@ class PySearchStrategy(SearchStrategy): def __init__(self): """Constructor.""" + @check_override(self.__class__, SearchStrategy) def f_initialize_with_tune_context(context: "TuneContext") -> None: self.initialize_with_tune_context(context) + @check_override(self.__class__, SearchStrategy) def f_pre_tuning(design_spaces: List[Schedule]) -> None: self.pre_tuning(design_spaces) + @check_override(self.__class__, SearchStrategy) def f_post_tuning() -> None: self.post_tuning() + @check_override(self.__class__, SearchStrategy) def f_generate_measure_candidates() -> List[MeasureCandidate]: return self.generate_measure_candidates() + @check_override(self.__class__, SearchStrategy) def f_notify_runner_results(results: List["RunnerResult"]) -> None: self.notify_runner_results(results) @@ -149,18 +158,3 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None: f_generate_measure_candidates, f_notify_runner_results, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def pre_tuning(self, design_spaces: List[Schedule]) -> None: - raise NotImplementedError - - def post_tuning(self) -> None: - raise NotImplementedError - - def generate_measure_candidates(self) -> List[MeasureCandidate]: - raise NotImplementedError - - def notify_runner_results(self, results: List["RunnerResult"]) -> None: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index af759d43b3..fc08cd491d 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -19,7 +19,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ - from .space_generator import SpaceGenerator, PySpaceGenerator from .space_generator_union import SpaceGeneratorUnion from .schedule_fn import ScheduleFn +from .post_order_apply import PostOrderApply diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py new file mode 100644 index 0000000000..a9b2d56031 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Post Order Apply Space Generator.""" + + +from tvm._ffi import register_object +from .space_generator import SpaceGenerator +from .. import _ffi_api + + +@register_object("meta_schedule.PostOrderApply") +class PostOrderApply(SpaceGenerator): + """ + PostOrderApply is the design space generator that generates design spaces by applying schedule + rules to blocks in post-DFS order. 
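[Editor's note] Because PySearchStrategy now enforces its overrides through check_override instead of NotImplementedError stubs, a minimal conforming subclass is sketched below. Proposing an empty batch of candidates is only to keep the sketch short; it is not a useful strategy.

# Sketch: the five hooks a Python search strategy must provide.
from tvm.meta_schedule.search_strategy import PySearchStrategy


class NullStrategy(PySearchStrategy):
    """Toy strategy that proposes nothing; it only shows the required hooks."""

    def initialize_with_tune_context(self, tune_context) -> None:
        pass

    def pre_tuning(self, design_spaces) -> None:
        pass

    def post_tuning(self) -> None:
        pass

    def generate_measure_candidates(self):
        return []  # a real strategy returns MeasureCandidate objects here

    def notify_runner_results(self, results) -> None:
        pass  # a real strategy uses the results to guide the next round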
+ """ + + def __init__(self): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 798753d913..2172613ce1 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,7 +18,6 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ - from typing import TYPE_CHECKING, List from tvm._ffi import register_object @@ -27,6 +26,7 @@ from tvm.tir.schedule import Schedule from .. import _ffi_api +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -36,10 +36,7 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters @@ -74,9 +71,11 @@ class PySpaceGenerator(SpaceGenerator): def __init__(self): """Constructor.""" + @check_override(self.__class__, SpaceGenerator) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, SpaceGenerator) def f_generate_design_space(mod: IRModule) -> List[Schedule]: return self.generate_design_space(mod) @@ -85,9 +84,3 @@ def f_generate_design_space(mod: IRModule) -> List[Schedule]: f_initialize_with_tune_context, f_generate_design_space, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def generate_design_space(self, mod: IRModule) -> List[Schedule]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f5..ab2d11ea7b 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -19,6 +19,7 @@ from typing import List, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner @@ -41,6 +42,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + measure_callbacks: List[MeasureCallback] = [], ) -> None: """Constructor. @@ -54,6 +56,8 @@ def __init__( The runner. database : Database The database. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +65,5 @@ def __init__( builder, runner, database, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index f1e21ad3dd..2af3852ab7 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,19 +15,68 @@ # specific language governing permissions and limitations # under the License. 
"""Auto-tuning Task Scheduler""" + +from typing import List + from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object +from ..runner import Runner +from ..builder import Builder +from ..database import Database +from ..tune_context import TuneContext from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.TaskScheduler") class TaskScheduler(Object): - """The abstract task scheduler interface.""" + """The abstract task scheduler interface. + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. + """ + + tasks: List[TuneContext] + builder: Builder + runner: Runner + database: Database + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" - _ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member + + def next_task_id(self) -> int: + """Fetch the next task id. + + Returns + ------- + int + The next task id. + """ + return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + + def _initialize_task(self, task_id: int) -> None: + """Initialize modules of the given task. + + Parameters + ---------- + task_id : int + The task id to be initialized. + """ + _ffi_api.TaskSchedulerInitializeTask(self, task_id) # pylint: disable=no-member def _set_task_stopped(self, task_id: int) -> None: """Set specific task to be stopped. @@ -37,7 +86,7 @@ def _set_task_stopped(self, task_id: int) -> None: task_id : int The task id to be stopped. """ - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member def _is_task_running(self, task_id: int) -> bool: """Check whether the task is running. @@ -52,7 +101,7 @@ def _is_task_running(self, task_id: int) -> bool: bool Whether the task is running. """ - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member def _join_running_task(self, task_id: int) -> None: """Wait until the task is finished. @@ -62,61 +111,80 @@ def _join_running_task(self, task_id: int) -> None: task_id : int The task id to be joined. """ - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member - - def _next_task_id(self) -> int: - """Fetch the next task id. - - Returns - ------- - int - The next task id. - """ - return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member @register_object("meta_schedule.PyTaskScheduler") class PyTaskScheduler(TaskScheduler): """An abstract task scheduler with customized methods on the python-side.""" - def __init__(self): - """Constructor.""" + def __init__( + self, + tasks: List[TuneContext], + builder: Builder, + runner: Runner, + database: Database, + measure_callbacks: List[MeasureCallback] = [], + ): + """Constructor. + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. 
+ builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. + """ + @check_override(self.__class__, TaskScheduler, required=False) def f_tune() -> None: self.tune() + @check_override(self.__class__, TaskScheduler) + def f_next_task_id() -> int: + return self.next_task_id() + + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_initialize_task" + ) + def f_initialize_task(task_id: int) -> None: + self._initialize_task(task_id) + + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_set_task_stopped" + ) def f_set_task_stopped(task_id: int) -> None: self._set_task_stopped(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_is_task_running" + ) def f_is_task_running(task_id: int) -> bool: return self._is_task_running(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_join_running_task" + ) def f_join_running_task(task_id: int) -> None: self._join_running_task(task_id) - def f_next_task_id() -> int: - return self._next_task_id() - self.__init_handle_by_constructor__( - _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member + tasks, + builder, + runner, + database, + measure_callbacks, f_tune, + f_initialize_task, f_set_task_stopped, f_is_task_running, f_join_running_task, f_next_task_id, ) - - def tune(self) -> None: - raise NotImplementedError() - - def _set_task_stopped(self, task_id: int) -> None: - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member - - def _is_task_running(self, task_id: int) -> bool: - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member - - def _join_running_task(self, task_id: int) -> None: - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member - - def _next_task_id(self) -> int: - return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py new file mode 100644 index 0000000000..72d84c9bb4 --- /dev/null +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
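[Editor's note] Given the new constructor arguments and the check_override wiring above, a sketch of the minimum a Python-side scheduler has to provide. The "always pick task 0" policy is purely illustrative, and the builder, runner, and database objects are assumed to come from elsewhere in the meta schedule stack.

# Sketch: a minimal PyTaskScheduler subclass.
from tvm.meta_schedule.task_scheduler import PyTaskScheduler


class FirstTaskOnly(PyTaskScheduler):
    """Toy scheduler that always reports task 0 as the next task to tune."""

    # `tune`, `_initialize_task`, `_set_task_stopped`, `_is_task_running` and
    # `_join_running_task` are optional overrides; omitting them falls back to
    # the C++ defaults.

    def next_task_id(self) -> int:
        return 0


# Instantiation needs the surrounding infrastructure, roughly:
#   scheduler = FirstTaskOnly(tasks, builder, runner, database, measure_callbacks=[])
#   scheduler.tune()
# where `tasks` is a list of TuneContext and builder/runner/database implement the
# corresponding meta schedule interfaces.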
+"""Testing utilities in meta schedule""" +from .local_rpc import LocalRPC +from .relay_workload import get_torch_model, MODEL_TYPE, MODEL_TYPES diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing/local_rpc.py similarity index 97% rename from python/tvm/meta_schedule/testing.py rename to python/tvm/meta_schedule/testing/local_rpc.py index b286e3b18a..cd1221124c 100644 --- a/python/tvm/meta_schedule/testing.py +++ b/python/tvm/meta_schedule/testing/local_rpc.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Testing utilities in meta schedule""" +"""RPC tracker and server running locally""" from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py new file mode 100644 index 0000000000..3cb777d6ae --- /dev/null +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -0,0 +1,170 @@ +from tvm import relay +from tvm.ir import IRModule +from typing import Dict, List, Tuple +from tvm.runtime import NDArray +from enum import Enum + +# Model types supported in Torchvision +class MODEL_TYPE(Enum): + IMAGE_CLASSIFICATION = (1,) + VIDEO_CLASSIFICATION = (2,) + SEGMENTATION = (3,) + OBJECT_DETECTION = (4,) + + +# Specify the type of each model +MODEL_TYPES = { + # Image classification models + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + 
"fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. + # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, +} + + +def get_torch_model( + model_name: str, + input_shape: Tuple[int, ...], + output_shape: Tuple[int, int], + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray]]: + """Load model from torch model zoo + Parameters + ---------- + model_name : str + The name of the model to load + input_shape: Tuple[int, ...] + Tuple for input shape + output_shape: Tuple[int, int] + Tuple for output shape + dtype: str + Tensor data type + """ + + assert dtype == "float32" + + import torch + import torchvision.models as models + + def do_trace(model, inp): + model_trace = torch.jit.trace(model, inp) + model_trace.eval() + return model_trace + + # Load model from torchvision + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + else: + raise ValueError("Unsupported model in Torch model zoo.") + + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. 
+ if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + out = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + + # Convert torch model to relay module + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 0f3cfac1a8..af21908639 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,7 +16,7 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, List from tvm import IRModule from tvm._ffi import register_object @@ -29,6 +29,9 @@ if TYPE_CHECKING: from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy + from .schedule_rule import ScheduleRule + from .postproc import Postproc + from .mutator import Mutator @register_object("meta_schedule.TuneContext") @@ -50,6 +53,12 @@ class TuneContext(Object): The design space generator. search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules : List[ScheduleRule] = [] + The schedule rules. + postproc : List[Postproc] = [] + The post processings. + mutator : List[Mutator] = [] + The mutators. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -68,8 +77,11 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] + sch_rules: List["ScheduleRule"] + postproc: List["Postproc"] + mutator: List["Mutator"] task_name: Optional[str] rand_state: int num_threads: int @@ -80,6 +92,9 @@ def __init__( target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, + sch_rules: List["ScheduleRule"] = [], + postproc: List["Postproc"] = [], + mutator: List["Mutator"] = [], task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, @@ -96,6 +111,12 @@ def __init__( The design space generator. search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules : List[ScheduleRule] = [] + The schedule rules. + postproc : List[Postproc] = [] + The post processings. + mutator : List[Mutator] = [] + The mutators. task_name : Optional[str] = None The name of the tuning task. 
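[Editor's note] A usage sketch for the get_torch_model helper above, feeding its output straight into task extraction. It assumes torch and torchvision are installed; resnet50 with a 1x3x224x224 input is just the conventional example shape.

# Sketch: load a torchvision model via the new testing helper and extract tasks.
from tvm.meta_schedule.integration import extract_task
from tvm.meta_schedule.testing import get_torch_model

mod, params = get_torch_model(
    "resnet50",
    input_shape=(1, 3, 224, 224),
    output_shape=(1, 1000),
    dtype="float32",
)
tasks = extract_task(mod, params, target="llvm")
print(f"extracted {len(tasks)} tasks")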
rand_state : int = -1 @@ -113,6 +134,9 @@ def __init__( target, space_generator, search_strategy, + sch_rules, + postproc, + mutator, task_name, rand_state, num_threads, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index c79137d55d..39f0fe38fc 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Utilities for meta schedule""" +import ctypes import json import os import shutil @@ -22,6 +23,7 @@ import psutil # type: ignore import tvm +from tvm import meta_schedule from tvm._ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, Map, IRModule @@ -205,3 +207,60 @@ def structural_hash(mod: IRModule) -> str: # but ffi can't handle unsigned integers properly so it's parsed into a negative number shash += 1 << 64 return str(shash) + + +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + +def check_override( + derived_class: Any, base_class: Any, required: bool = True, func_name: str = None +) -> Callable: + """Check if the derived class has overrided the base class's method. + + Parameters + ---------- + derived_class : Any + The derived class. + base_class : Any + The base class of derived class. + required : bool + If the method override is required. + func_name : str + Name of the method. Default value None, which would be set to substring of the given + function, e.g. `f_generate`->`generate`. + + Returns + ------- + func : Callable + Raise NotImplementedError if the function is required and not overrided. If the + function is not overrided return None, other return the overrided function. + """ + + def inner(func: Callable): + + if func_name is None: + method = func.__name__[2:] + else: + method = func_name + + if getattr(derived_class, method) is getattr(base_class, method): + if required: + raise NotImplementedError(f"{derived_class}'s {method} method is not implemented!") + else: + return None + return func + + return inner diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 250c165caf..308257085e 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -33,7 +33,7 @@ from .tag import tag_scope from .operation import placeholder, compute, scan, extern, var, size_var from .operation import thread_axis, reduce_axis -from .operation import create_prim_func +from .operation import create_prim_func, create_prim_func_from_outputs from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index cb0305d49e..796195b488 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -482,3 +482,21 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops) + + +def create_prim_func_from_outputs(outputs: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from output tensors in TE + + Parameters + ---------- + outputs : List[Tensor] + The source expression. 
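[Editor's note] To make the contract of check_override above concrete, a small self-contained sketch outside of the TVM class hierarchy; Base and Derived are throwaway names.

# Sketch: check_override passes a function through only if the derived class
# actually overrides the corresponding method of the base class.
from tvm.meta_schedule.utils import check_override


class Base:
    def apply(self):
        raise NotImplementedError


class Derived(Base):
    def apply(self):
        return 42


@check_override(Derived, Base)  # method name inferred from `f_apply` -> `apply`
def f_apply():
    return Derived().apply()


print(f_apply())  # 42; without the override and with required=True this would raise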
+ + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(outputs, (list, tuple, Array)): + outputs = [outputs] + return _ffi_api.CreatePrimFuncFromOutputs(outputs) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 786982cf70..3bb56aeab0 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -325,6 +325,39 @@ def sample_categorical( decision, ) + def sample_perfect_tile( + self, + loop: LoopRV, + n: int, + max_innermost_factor: int = 16, + decision: Optional[List[int]] = None, + ) -> List[ExprRV]: + """Sample the factors to perfect tile a specific loop + + Parameters + ---------- + loop : LoopRV + The loop to be tiled + n : int + The number of tiles to be sampled + max_innermost_factor : int + The maximum tile size allowed to be sampled in the innermost loop + decision: Optional[List[int]] + The sampling decision, if any + + Returns + ------- + result : List[ExprRV] + A list of length `n`, the random perfect tile sizes sampled + """ + return _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member + self, + loop, + n, + max_innermost_factor, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, @@ -367,6 +400,21 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: """ return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]: + """Get the leaf blocks of a specific block/loop + + Parameters + ---------- + block_or_loop : Union[BlockRV, LoopRV] + The query block/loop + + Returns + ------- + blocks : List[LoopRV] + A list of leaf blocks inside a specific block/loop + """ + return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # pylint: disable=no-member + ########## Schedule: Transform loops ########## def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. It requires: @@ -926,6 +974,30 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## def compute_at( diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc new file mode 100644 index 0000000000..b1d549692a --- /dev/null +++ b/src/meta_schedule/integration.cc @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
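[Editor's note] A sketch exercising the new sampling and inspection primitives documented above on a small TE-derived workload. Splitting with the sampled factors is the typical follow-up; read_at and write_at are left out because their semantics live on the C++ side of this series.

# Sketch: sample_perfect_tile and get_child_blocks on a simple matmul PrimFunc.
import tvm
from tvm import te, tir

a = te.placeholder((128, 128), name="A")
b = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
c = te.compute((128, 128), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name="C")
sch = tir.Schedule(tvm.IRModule({"main": te.create_prim_func([a, b, c])}))

block = sch.get_block("C")
i, j, k_loop = sch.get_loops(block)

# Sample 4 factors whose product exactly covers the 128 iterations of `i`,
# then materialize the split with those factors.
factors = sch.sample_perfect_tile(i, n=4, max_innermost_factor=16)
sch.split(i, factors=factors)

# The root block's children are the leaf blocks of the function.
root = sch.get_block("root")
print(len(sch.get_child_blocks(root)))
print(sch.mod)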
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/**************** Utility functions ****************/ + +template +bool HasOnlyOneFunction(const IRModule& mod) { + if (mod->functions.size() != 1) { + return false; + } + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (!func->IsInstance()) { + return false; + } + } + return true; +} + +/**************** ExtractedTask ****************/ + +ExtractedTask::ExtractedTask(String task_name, IRModule mod, Array dispatched) { + ObjectPtr n = make_object(); + n->task_name = task_name; + n->mod = mod; + n->dispatched = dispatched; + data_ = n; +} + +/**************** IntegrationContext ****************/ + +struct IntegrationContextThreadLocalEntry { + Optional ctx; +}; + +using IntegrationContextThreadLocalStore = + dmlc::ThreadLocalStore; + +Optional IntegrationContext::Current() { + return IntegrationContextThreadLocalStore::Get()->ctx; +} + +void IntegrationContext::EnterWithScope() { + Optional& ctx = IntegrationContextThreadLocalStore::Get()->ctx; + CHECK(!ctx.defined()) << "ValueError: Nested IntegrationContext context managers are not allowed"; + ctx = *this; +} + +void IntegrationContext::ExitWithScope() { + Optional& ctx = IntegrationContextThreadLocalStore::Get()->ctx; + ICHECK(ctx.defined()); + ctx = NullOpt; +} + +Optional IntegrationContext::EntryPoint(runtime::String task_name, IRModule mod, + Optional> dispatched) { + if (Optional ctx = IntegrationContext::Current()) { + return ctx.value()->Query(task_name, mod, dispatched); + } + return NullOpt; +} + +/**************** TaskExtraction ****************/ + +TaskExtraction::TaskExtraction() { + ObjectPtr n = make_object(); + n->tasks = Array(); + data_ = n; +} + +Optional TaskExtractionNode::Query(runtime::String task_name, IRModule mod, + Optional> dispatched) { + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + ICHECK(HasOnlyOneFunction(mod)) << mod; + tasks.push_back(ExtractedTask(task_name, mod, {prim_mod})); + LOG(INFO) << "relay_mod:\n" << mod << "\nprim_mod:\n" << prim_mod; + return NullOpt; +} + +/**************** ApplyHistoryBest ****************/ + +ApplyHistoryBest::ApplyHistoryBest(Database database) { + ObjectPtr n = make_object(); + n->database = database; + data_ = n; +} + +Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, + Optional> dispatched) { + // TODO + throw; +} + +/**************** FFI ****************/ + +class IntegrationContextInternal { + public: + static void EnterScope(IntegrationContext ctx) { ctx.EnterWithScope(); } + static void ExitScope(IntegrationContext ctx) { ctx.ExitWithScope(); } +}; + +TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); +TVM_REGISTER_OBJECT_TYPE(IntegrationContextNode); +TVM_REGISTER_NODE_TYPE(TaskExtractionNode); +TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") + .set_body_typed([](String task_name, IRModule mod, + Array dispatched) -> ExtractedTask { + 
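[Editor's note] Mirroring the C++ query flow above from Python, a sketch of what the compiler does internally while a TaskExtraction scope is active. Calling the entry point by hand with a hand-built PrimFunc module is purely for illustration; normally the Relay TE compiler performs this query during lowering.

# Sketch: manual task extraction through the integration entry point.
import tvm
from tvm import relay, te
from tvm.ir import IRModule
from tvm.meta_schedule.integration import IntegrationContext, TaskExtraction

# High-level IR: a single Relay function.
x = relay.var("x", shape=(4,), dtype="float32")
relay_mod = IRModule.from_expr(relay.Function([x], x + x))

# Low-level IR: a single PrimFunc produced from a TE description of the same add.
a = te.placeholder((4,), name="A")
b = te.compute((4,), lambda i: a[i] + a[i], name="B")
prim_mod = tvm.IRModule({"main": te.create_prim_func([a, b])})

with TaskExtraction() as env:
    # The compiler normally issues this query; calling it here records one task.
    IntegrationContext.entry_point("toy_add", relay_mod, [prim_mod])
print(env.tasks[0].task_name)  # "toy_add"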
return ExtractedTask(task_name, mod, dispatched); + }); +TVM_REGISTER_GLOBAL("meta_schedule.IntegrationContextEnterScope") + .set_body_typed(IntegrationContextInternal::EnterScope); +TVM_REGISTER_GLOBAL("meta_schedule.IntegrationContextExitScope") + .set_body_typed(IntegrationContextInternal::ExitScope); +TVM_REGISTER_GLOBAL("meta_schedule.IntegrationContextCurrent") + .set_body_typed(IntegrationContext::Current); +TVM_REGISTER_GLOBAL("meta_schedule.IntegrationContextEntryPoint") + .set_body_typed(IntegrationContext::EntryPoint); +TVM_REGISTER_GLOBAL("meta_schedule.IntegrationContextQuery") + .set_body_method(&IntegrationContextNode::Query); +TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { + return TaskExtraction(); +}); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc new file mode 100644 index 0000000000..733d118c73 --- /dev/null +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return MeasureCallback(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") + .set_body_method(&MeasureCallbackNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") + .set_body_typed(MeasureCallback::PyMeasureCallback); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc new file mode 100644 index 0000000000..9bf6161b55 --- /dev/null +++ b/src/meta_schedule/mutator/mutator.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Mutator Mutator::PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Mutator(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMutatorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MutatorNode); +TVM_REGISTER_NODE_TYPE(PyMutatorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") + .set_body_method(&MutatorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply").set_body_method(&MutatorNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc new file mode 100644 index 0000000000..ff069e2c68 --- /dev/null +++ b/src/meta_schedule/postproc/postproc.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
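// [Editorial note — illustrative, not part of the patch] PyMeasureCallback above and the
// PyMutator / PyPostproc / PyScheduleRule classes that follow all share one pattern: the C++
// node only stores TypedPackedFunc fields and the behaviour is injected, normally from Python.
// A minimal C++ sketch, where `init_func` and `apply_func` stand for packed functions matching
// the signatures declared in the corresponding headers (not shown in this diff):
//
//   Mutator mutator = Mutator::PyMutator(
//       /*f_initialize_with_tune_context=*/init_func,
//       /*f_apply=*/apply_func,
//       /*f_as_string=*/TypedPackedFunc<String()>([]() { return String("MyMutator"); }));
//
// The ICHECK(f_as_string != nullptr) in each ReprPrinter dispatch exists because a Python
// subclass may forget to provide the string method; printing the object then fails with a
// clear message instead of crashing.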
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Postproc Postproc::PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Postproc(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyPostprocNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(PostprocNode); +TVM_REGISTER_NODE_TYPE(PyPostprocNode); + +TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") + .set_body_method(&PostprocNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc new file mode 100644 index 0000000000..f80f684daf --- /dev/null +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +ScheduleRule ScheduleRule::PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return ScheduleRule(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); +TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") + .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") + .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") + .set_body_typed(ScheduleRule::PyScheduleRule); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 0000000000..5c00b3dc04 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! 
\brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->space_generator_ = tune_context->space_generator.value(); + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + result.push_back(MeasureCandidate(schs[tir::SampleInt(&self->rand_state_, 0, schs.size())], + self->args_info_)); + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1c83aee8c0..c4ee3c4679 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -50,7 +50,7 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_total; /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; + Array mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; /*! \brief The number of threads to use. -1 means using logical cpu number. 
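// [Editorial note — illustrative driver loop, not part of the patch] ReplayFunc above does not
// replay traces at all: every trial re-runs the space generator on the original module and picks
// one of the generated schedules at random. A typical driver (normally the task scheduler) is:
//
//   SearchStrategy strategy = SearchStrategy::ReplayFunc(/*num_trials_per_iter=*/8,
//                                                        /*num_trials_total=*/32);
//   strategy->InitializeWithTuneContext(tune_context);   // binds mod_, space_generator_, seed
//   strategy->PreTuning(/*design_spaces=*/{});           // spaces are regenerated per trial
//   while (Optional<Array<MeasureCandidate>> batch = strategy->GenerateMeasureCandidates()) {
//     // ... build and run batch.value() ...
//     strategy->NotifyRunnerResults({});                 // advances the [st, ed) window
//   }
//   strategy->PostTuning();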
*/ @@ -74,9 +74,15 @@ class ReplayTraceNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& tune_context) final { - this->mod_ = tune_context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; this->num_threads_ = tune_context->num_threads; + + this->mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } @@ -118,7 +124,7 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure tir::Trace trace = design_spaces[design_space_index]->trace().value(); tir::Trace new_trace = tir::Trace(trace->insts, {}); tir::Schedule sch = tir::Schedule::Traced( // - self->mod_, // + self->mod_[thread_id], // /*rand_state=*/ForkSeed(&rand_state), // /*debug_mode=*/0, // /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); @@ -142,7 +148,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") + .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc new file mode 100644 index 0000000000..41afbc57d7 --- /dev/null +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Collecting all the non-root blocks */ +class BlockCollector : public tir::StmtVisitor { + public: + static Array Collect(const tir::Schedule& sch) { // + return BlockCollector(sch).Run(); + } + + private: + /*! 
\brief Entry point */ + Array Run() { + for (const auto& kv : sch_->mod()->functions) { + const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function + const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + func_name_ = gv->name_hint; + block_names_.clear(); + blocks_to_collect_.clear(); + root_block_ = func->body.as()->block.get(); + VisitStmt(func->body); + for (const String& block_name : blocks_to_collect_) { + results_.push_back(sch_->GetBlock(block_name, func_name_)); + } + } + } + return results_; + } + /*! \brief Constructor */ + explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + /*! \brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + tir::StmtVisitor::VisitStmt_(block); + if (block != root_block_) { + CHECK(block_names_.count(block->name_hint) == 0) + << "Duplicated block name " << block->name_hint << " in function " << func_name_ + << " not supported!"; + block_names_.insert(block->name_hint); + blocks_to_collect_.push_back(block->name_hint); + } + } + + /*! \brief The schedule to be collected */ + const tir::Schedule& sch_; + /*! \brief The set of func name and block name pair */ + std::unordered_set block_names_; + /* \brief The list of blocks to collect in order */ + Array blocks_to_collect_; + /*! \brief Function name & blocks of collection */ + Array results_; + /*! \brief The root block of the PrimFunc */ + const tir::BlockNode* root_block_; + /*! \brief Name of the current PrimFunc */ + String func_name_; +}; + +/*! + * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks + * in post-DFS order. + * */ +class PostOrderApplyNode : public SpaceGeneratorNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The schedule rules to be applied in order. 
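// [Editorial note — illustrative, not part of the patch] For a PrimFunc "main" whose root block
// directly contains two blocks named "A" and "B", BlockCollector::Collect(sch) above returns
//
//   {sch->GetBlock("A", "main"), sch->GetBlock("B", "main")}
//
// i.e. every non-root block of every PrimFunc in the module, in visiting order, re-looked-up as
// BlockRVs so later rule applications can re-query them. Duplicate block names inside one
// PrimFunc are rejected because GetBlock(name, func_name) would then be ambiguous.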
*/ + Array sch_rules_{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `rand_state_` is not visited + // `sch_rules_` is not visited + } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->sch_rules_ = tune_context->sch_rules; + } + + Array GenerateDesignSpace(const IRModule& mod_) final { + using ScheduleAndUnvisitedBlocks = std::pair>; + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail // + ); + + std::vector stack; + Array result{sch}; + // Enumerate the schedule rules first because you can + // always concat multiple schedule rules as one + for (ScheduleRule sch_rule : sch_rules_) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, BlockCollector::Collect(sch)); + } + result.clear(); + + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } + } + } + return result; + } + static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; + TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); +}; + +SpaceGenerator SpaceGenerator::PostOrderApply() { + ObjectPtr n = make_object(); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") + .set_body_typed(SpaceGenerator::PostOrderApply); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index a529f2354d..2bd7cf4bcf 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,13 +52,17 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, Builder builder, Runner runner, - Database database) { +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->measure_callbacks = measure_callbacks; n->task_id = -1; return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index cf0af3d55f..3bd2ab9dde 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -92,19 +92,40 @@ Array SendToRunner(const Runner& runner, // return results; } +void TaskSchedulerNode::InitializeTask(int task_id) { + TuneContext task = this->tasks[task_id]; + // Derive the values. + IRModule mod = task->mod.value(); + SpaceGenerator space = task->space_generator.value(); + SearchStrategy strategy = task->search_strategy.value(); + // Initialize Modules. 
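// [Editorial note — worked example, not part of the patch] The nested loops in
// PostOrderApplyNode::GenerateDesignSpace above enumerate the cross product of rule decisions.
// With sch_rules_ = {R} where R returns two schedules per block, and a module whose collected
// blocks are {B1, B2}, the stack evolves roughly as:
//
//   stack = {(s0, [B1, B2])}
//   apply R to B2  ->  (s1, [B1]), (s2, [B1])
//   apply R to B1  ->  s11, s12, s21, s22          // four design spaces in `result`
//
// A second rule then starts from all four, and so on. sch->HasBlock(block_rv) skips blocks that
// an earlier application has already removed (e.g. by inlining), which is why the collected
// BlockRVs are re-validated instead of assumed to stay live.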
+ space->InitializeWithTuneContext(task); + strategy->InitializeWithTuneContext(task); + // Initialize the rules. + for (const ScheduleRule& sch_rule : task->sch_rules) { + sch_rule->InitializeWithTuneContext(task); + } + for (const Mutator& mutator : task->mutators) { + mutator->InitializeWithTuneContext(task); + } + for (const Postproc& postproc : task->postprocs) { + postproc->InitializeWithTuneContext(task); + } +} + void TaskSchedulerNode::Tune() { - for (const TuneContext& task : this->tasks) { - CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(task->space_generator.defined()) + for (int i = 0; i < static_cast(this->tasks.size()); i++) { + // Check Optional value validity. + CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(tasks[i]->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(task->search_strategy.defined()) + CHECK(tasks[i]->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); - strategy->PreTuning(space->GenerateDesignSpace(mod)); + + InitializeTask(i); + + tasks[i]->search_strategy.value()->PreTuning( + tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value())); } int running_tasks = tasks.size(); @@ -114,7 +135,7 @@ void TaskSchedulerNode::Tune() { ICHECK(!task->is_stopped); ICHECK(!task->runner_futures.defined()); SearchStrategy strategy = task->search_strategy.value(); - if (task->measure_candidates = strategy->GenerateMeasureCandidates()) { + if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { Array builder_results = SendToBuilder(this->builder, task, task->measure_candidates.value()); task->runner_futures = @@ -186,13 +207,25 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { } TaskScheduler TaskScheduler::PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // + PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // PyTaskSchedulerNode::FNextTaskId f_next_task_id) { ObjectPtr n = make_object(); + n->tasks = tasks; + n->builder = builder; + n->runner = runner; + n->database = database; + n->measure_callbacks = measure_callbacks; n->f_tune = f_tune; + n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; n->f_is_task_running = f_is_task_running; n->f_join_running_task = f_join_running_task; @@ -202,14 +235,16 @@ TaskScheduler TaskScheduler::PyTaskScheduler( TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); -TVM_REGISTER_GLOBAL("tvm.task.TaskSchedulerPyTaskScheduler") +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") .set_body_typed(TaskScheduler::PyTaskScheduler); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") + .set_body_method(&TaskSchedulerNode::Tune); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerInitializeTask") + .set_body_method(&TaskSchedulerNode::InitializeTask); 
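// [Editorial note — illustrative, not part of the patch] With the InitializeTask refactor and
// the extra `measure_callbacks` argument above, a scheduler is now assembled roughly as follows
// (builder / runner / database / tune_context are placeholders assumed to exist):
//
//   TaskScheduler scheduler = TaskScheduler::RoundRobin(
//       /*tasks=*/{tune_context},
//       /*builder=*/builder,
//       /*runner=*/runner,
//       /*database=*/database,
//       /*measure_callbacks=*/{});
//   scheduler->Tune();   // per task: InitializeTask -> PreTuning -> candidate/measure loop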
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped") .set_body_method(&TaskSchedulerNode::SetTaskStopped); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning") .set_body_method(&TaskSchedulerNode::IsTaskRunning); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") - .set_body_method(&TaskSchedulerNode::Tune); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 9fc9272e33..21ba8294fc 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -24,20 +24,13 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Constructor function of TuneContext class. - * \param mod The mod to be optimized. - * \param target The target to be optimized for. - * \param space_generator The design space generator. - * \param task_name The name of the tuning task. - * \param rand_state The random state. - * \param num_threads The number of threads to be used. - * \param verbose The verbosity level. - */ TuneContext::TuneContext(Optional mod, // Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -46,6 +39,9 @@ TuneContext::TuneContext(Optional mod, n->target = target; n->space_generator = space_generator; n->search_strategy = search_strategy; + n->sch_rules = sch_rules; + n->postprocs = postprocs; + n->mutators = mutators; n->task_name = task_name; if (rand_state == -1) { rand_state = std::random_device()(); @@ -65,11 +61,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state, - num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, + mutators, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 83e65a5ced..be76d3e8db 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,7 +23,11 @@ #include #include #include +#include +#include +#include #include +#include #include #include #include @@ -32,6 +36,7 @@ #include #include #include +#include #include #include @@ -193,7 +198,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed( /*! * \brief Fork a random state into another ones, i.e. PRNG splitting. - * The given random state is also mutated. + * The given random state is also mutated. * \param rand_state The random state to be forked * \param n The number of forks * \return The forked random states @@ -208,6 +213,15 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Get deep copy of an IRModule. + * \param mod The IRModule to make a deep copy. + * \return The deep copy of the IRModule. 
+ */ +inline IRModule DeepCopyIRModule(IRModule mod) { + return Downcast(LoadJSON(SaveJSON(mod))); +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e1ed3d47d3..79bd975025 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -222,7 +222,8 @@ class TECompilerImpl : public TECompilerNode { auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + tir::PrimFunc{nullptr}, {}, ir_module); return value; } @@ -243,16 +244,19 @@ class TECompilerImpl : public TECompilerNode { return value; } } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); + if (cfunc->prim_func.defined()) { + cfunc->funcs->Update(cfunc->prim_fn_var, cfunc->prim_func.value()); + } else { + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + // lower the function + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); } - - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); value->cached_func = cfunc; return value; } diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 3970b0e806..5bdc568b3f 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -70,7 +70,8 @@ CCacheKey::CCacheKey(Function source_func, Target target) { CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, tvm::Array outputs, te::Schedule schedule, - tvm::Array shape_func_param_states, IRModule funcs) { + tir::PrimFunc prim_func, tvm::Array shape_func_param_states, + IRModule funcs) { auto n = make_object(); n->target = target; n->prim_fn_var = prim_fn_var; @@ -117,11 +118,12 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. 
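// [Editorial note — illustrative, not part of the patch] DeepCopyIRModule above round-trips the
// module through the JSON node serializer, which is a simple way to obtain a structurally
// independent copy; ReplayTrace uses it so each worker thread mutates its own module:
//
//   IRModule copy = DeepCopyIRModule(mod);   // mod and copy share no mutable nodes
//
// This trades some serialization overhead for not having to maintain a dedicated copying
// visitor; if cloning ever shows up in profiles, a structural copier could replace it without
// changing the call sites.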
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + use_meta_schedule_ = backend::IsMetaScheduleEnabled(); } - CachedFunc Create(const Function& prim_func, std::function renamer) { + CachedFunc Create(const Function& relay_func, std::function renamer) { Array fn_inputs; - for (Var param : prim_func->params) { + for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); @@ -131,7 +133,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator memo_[param] = inputs; } readable_name_stream_ << "fused"; - auto outputs = this->VisitExpr(prim_func->body); + auto outputs = this->VisitExpr(relay_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME @@ -151,7 +153,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator prim_fn_name = renamer(prim_fn_name); } auto prim_fn_var = GlobalVar(prim_fn_name); - prim_fn_var->checked_type_ = prim_func->checked_type(); + prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. @@ -163,7 +165,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } - te::Schedule schedule; + te::Schedule schedule{nullptr}; + tir::PrimFunc prim_func{nullptr}; // No need to register schedule for device copy op. if (anchor_attrs_.as() == nullptr && create_schedule_) { if (use_auto_scheduler_) { @@ -176,20 +179,38 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator schedule = Downcast(obj); } } + if (use_meta_schedule_) { + const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs"); + const auto* f_meta_schedule = + runtime::Registry::Get("meta_schedule.IntegrationContextEntryPoint"); + ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered"; + ICHECK(f_meta_schedule) << "meta_schedule.IntegrationContextEntryPoint is not registered"; + prim_func = (*f_create_func)(tensor_outs); + Optional opt_mod_or_base_func = + (*f_meta_schedule)(prim_fn_name, IRModule({{GlobalVar(prim_fn_name), relay_func}}), + Array{IRModule({{GlobalVar(prim_fn_name), prim_func}})}); + if (const auto* result = opt_mod_or_base_func.as()) { + prim_func = GetRef(result); + } else { + prim_func = tir::PrimFunc(nullptr); + } + } // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. 
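// [Editorial note — flow summary, not part of the patch] The block added above queries the meta
// schedule integration for the anchor op, effectively:
//
//   // prim_func = CreatePrimFuncFromOutputs(tensor_outs);            // via the packed func
//   // best = IntegrationContext::EntryPoint(prim_fn_name, relay_mod, {prim_mod});
//   // prim_func = best.defined() ? Downcast<tir::PrimFunc>(best) : tir::PrimFunc(nullptr);
//
// In a TaskExtraction scope the entry point only records the (relay_mod, prim_mod) pair and
// returns NullOpt, so `prim_func` is reset and the TOPI schedule path below is taken; under
// ApplyHistoryBest (once implemented) a tuned function is expected back, the te::Schedule path
// is skipped, and CachedFuncNode::prim_func carries the result to the TE compiler.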
- if (!schedule.defined()) { + if (!schedule.defined() && !prim_func.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); + if (schedule.defined()) { + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } } } } - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}); } Array VisitExpr_(const VarNode* op) final { @@ -336,6 +357,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator std::ostringstream readable_name_stream_; Array scalars_; bool use_auto_scheduler_; + bool use_meta_schedule_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; @@ -450,8 +472,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> std::unordered_map binds; IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); - return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, - ir_module); + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr}, + shape_func_param_states, ir_module); } Array VisitExpr(const Expr& expr) final { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 7975ef8731..2171880fd6 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -129,16 +129,18 @@ class CCacheKey : public ObjectRef { /*! \brief Node container to represent a cached function. */ struct CachedFuncNode : public Object { - /* \brief compiled target */ + /*! \brief compiled target */ tvm::Target target; /*! \brief Primitive Function Name */ GlobalVar prim_fn_var; - /* \brief The inputs to the function */ + /*! \brief The inputs to the function */ tvm::Array inputs; - /* \brief The outputs to the function */ + /*! \brief The outputs to the function */ tvm::Array outputs; /*! \brief The schedule to the function */ te::Schedule schedule; + /*! \brief The TIR function if lowering in the meta schedule path */ + Optional prim_func; /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; /*! \brief The lowered functions to support the function. */ @@ -150,6 +152,7 @@ struct CachedFuncNode : public Object { v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); v->Visit("schedule", &schedule); + v->Visit("prim_func", &prim_func); v->Visit("funcs", &funcs); v->Visit("shape_func_param_states", &shape_func_param_states); } @@ -161,7 +164,7 @@ struct CachedFuncNode : public Object { class CachedFunc : public ObjectRef { public: CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, - tvm::Array outputs, te::Schedule schedule, + tvm::Array outputs, te::Schedule schedule, tir::PrimFunc prim_func, tvm::Array shape_func_param_states, IRModule funcs = IRModule(Map({}))); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f89a099b0d..16cbe0e8db 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -427,6 +427,15 @@ inline bool IsAutoSchedulerEnabled() { .value(); } +/*! + * \brief Return whether the meta schedule is enabled in the pass context. 
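// [Editorial note — illustrative, not part of the patch] The flag read by the helper declared
// here is a regular pass-config option, so the C++ equivalent of enabling the meta schedule
// backend looks roughly like this (assuming the option is registered elsewhere in the tree):
//
//   transform::PassContext pass_ctx = transform::PassContext::Create();
//   pass_ctx->config.Set("relay.backend.use_meta_schedule", Bool(true));
//   {
//     With<transform::PassContext> scope(pass_ctx);
//     // relay build calls made here observe IsMetaScheduleEnabled() == true
//   }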
+ */ +inline bool IsMetaScheduleEnabled() { + return transform::PassContext::Current() + ->GetConfig("relay.backend.use_meta_schedule", Bool(false)) + .value(); +} + /*! * \brief Get the sequence of Relay optimization passes based on backend type. * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 657dc12196..d90681a1c0 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -22,6 +22,7 @@ #include #include +#include #include "../schedule/graph.h" @@ -300,9 +301,40 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { return (*complete)(func, info.root_alloc); } // namespace tir -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array& tensors) { - return CreatePrimFunc(tensors); -}); +PrimFunc CreatePrimFuncFromOutputs(const Array& outputs) { + std::vector stack; + std::unordered_set visited; + for (const te::Tensor& output : outputs) { + if (!visited.count(output.get())) { + visited.insert(output.get()); + stack.push_back(output); + } + } + + Array arg_list; + while (!stack.empty()) { + te::Tensor tensor = stack.back(); + stack.pop_back(); + if (tensor->op->IsInstance()) { + arg_list.push_back(tensor); + } else if (tensor->op->IsInstance()) { + Array inputs = tensor->op->InputTensors(); + for (const te::Tensor& input : inputs) { + if (!visited.count(input.get())) { + visited.insert(input.get()); + stack.push_back(input); + } + } + } + } + for (const te::Tensor& output : outputs) { + arg_list.push_back(output); + } + return CreatePrimFunc(arg_list); +} + +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); +TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d14d64a4c7..e3a535e9b3 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -505,8 +505,8 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, if (const ForNode* loop = p->StmtAs()) { if (loop->kind == ForKind::kThreadBinding) { const String& thread_tag = loop->thread_binding.value()->thread_tag; - if (CanRelaxStorageUndereThread(extra_relax_scope, - runtime::ThreadScope::Create(thread_tag))) { + if (CanRelaxStorageUnderThread(extra_relax_scope, + runtime::ThreadScope::Create(thread_tag))) { result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 42839075af..54760abbe5 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -232,6 +232,16 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, throw; } +Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision)); + TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -282,6 +292,24 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return 
CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } +Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv), false)); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + +Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv), false)); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { @@ -435,6 +463,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 1f9aeecfc7..d053e3329f 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -71,6 +71,7 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline bool HasBlock(const BlockRV& block_rv) const final; inline Array GetSRefs(const Array& rvs) const; inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } @@ -80,19 +81,15 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. 
- * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) override; + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; + Array GetChildBlocks(const BlockRV& block_rv) override; + Array GetChildBlocks(const LoopRV& loop_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; @@ -107,6 +104,11 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -154,6 +156,12 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variable created */ inline ExprRV CreateRV(int64_t value); + /*! + * \brief Add a list of integers as random variables into the symbol table + * \param value The list of integers to be added to the symbol table + * \return The new random variables created + */ + inline Array CreateRV(const std::vector& value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -187,6 +195,19 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } +inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { + auto it = this->symbol_table_.find(block_rv); + if (it == this->symbol_table_.end()) { + return false; + } + const ObjectRef& obj = (*it).second; + const auto* sref = obj.as(); + if (sref == nullptr || sref->stmt == nullptr) { + return false; + } + return true; +} + inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { @@ -274,6 +295,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { + Array results; + results.reserve(value.size()); + for (int64_t v : value) { + results.push_back(CreateRV(v)); + } + return results; +} + inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 057e845dbd..aa726c48b0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -32,8 +32,8 @@ namespace tir { * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. 
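// [Editorial note — illustrative, not part of the patch] Typical use of the SamplePerfectTile
// instruction added above, here with n = 4 tiles (the loop and the factor bound are placeholders):
//
//   Array<ExprRV> f = sch->SamplePerfectTile(loop_rv, /*n=*/4, /*max_innermost_factor=*/16);
//   Array<LoopRV> tiles = sch->Split(loop_rv, {f[0], f[1], f[2], f[3]});
//
// The sampled factors multiply back to the loop extent, the innermost factor is bounded by
// max_innermost_factor, and the decision is recorded on the trace so mutators can resample it.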
*/ -TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive); +TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t min_inclusive, int32_t max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update @@ -46,6 +46,25 @@ TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Sample the factors to perfect tile a specific loop + * \param rand_state The random state + * \param loop_sref The loop to be tiled + * \param n The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_splits); +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_split, int32_t max_innermost_factor); +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, + Optional>* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -63,6 +82,15 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); +/*! + * \brief Get the leaf blocks of a specific block/loop + * \param self The schedule state + * \param parent_sref The query block/loop + * \param inclusive Whether to include parent_sref + * \return A list of leaf blocks inside a specific block/loop + */ +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref, + bool inclusive = false); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: @@ -168,6 +196,15 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! 
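// [Editorial note — illustrative, not part of the patch] Sketch of the new data-movement
// primitives declared above at the Schedule level; the block/loop names and the "shared"
// storage scope are placeholders:
//
//   BlockRV block = sch->GetBlock("C", "main");
//   Array<LoopRV> loops = sch->GetLoops(block);
//   // stage the 0-th read buffer of `block` into a "shared"-scope copy under loops[1]
//   BlockRV copy_block = sch->ReadAt(loops[1], block, /*read_buffer_index=*/0, "shared");
//
// The generated copy block carries the "auto_copy" = 1 annotation (see read_write_at.cc below),
// so later postprocessors can recognise and rewrite it.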
* \brief Move a producer block under the specific loop, and regenerate the diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 8b32a9c14f..4835c0854c 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -55,6 +55,31 @@ Array GetLoops(const StmtSRef& block_sref) { return {result.rbegin(), result.rend()}; } +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref, + bool inclusive) { + struct Collector : public StmtVisitor { + private: + void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } + + public: + explicit Collector(const ScheduleState& self) : self(self) {} + + const ScheduleState& self; + Array result; + }; + Collector collector(self); + if (inclusive) { + collector(GetRef(parent_sref->stmt)); + } else if (parent_sref->stmt->IsInstance()) { + const auto* loop = static_cast(parent_sref->stmt); + collector(loop->body); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); + collector(block->body); + } + return std::move(collector.result); +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -106,8 +131,39 @@ struct GetLoopsTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetChildBlocksTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetChildBlocks"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(block)); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(loop)); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, String block_or_loop_rv) { + PythonAPICall py("get_child_blocks"); + py.Input("", block_or_loop_rv); + py.OutputList(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 0000000000..cb693c77cd --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
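// [Editorial note — illustrative, not part of the patch] Behaviour of GetChildBlocks added in
// get_block_loop.cc above: for a loop nest
//
//   for i                      <- parent_sref
//     block "A"
//     for j
//       block "B"
//
// GetChildBlocks(self, parent_sref, /*inclusive=*/false) returns the srefs of "A" and "B",
// i.e. the blocks reachable from the loop body without descending into another block. The
// traced form sch->GetChildBlocks(block_or_loop_rv) is recorded as a `GetChildBlocks`
// instruction and printed as `get_child_blocks(...)` in the Python view of the trace.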
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../utils.h" +#include "tvm/runtime/memory.h" +#include "tvm/runtime/object.h" +#include "tvm/tir/schedule/block_scope.h" +#include "tvm/tir/stmt_functor.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class BufferReplacer : public StmtExprMutator { + public: + explicit BufferReplacer(const Buffer& src, const Buffer& dst, Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + 
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer src = + GetNthAccessBuffer(self, GetRef(block), buffer_index, /*is_write=*/!is_read); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = true; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. 
before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + BufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + 
loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_, loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{"auto_copy", Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{"auto_copy", Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6ac6226118..4acf618601 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,30 +20,156 @@ #include #include "../utils.h" +#include "tvm/support/random_engine.h" namespace tvm { namespace tir { -int 
SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive) { +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int32_t kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int32_t kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int32_t min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! \brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int32_t i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int32_t x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int64_t factor = primes[i]; + int64_t y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), static_cast(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int32_t prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! + * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... 
(a_l ^ b_l) + */ + std::vector> Factorize(int32_t n) const { + std::vector> result; + result.reserve(16); + int32_t i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int32_t j; n > 1;) { + int32_t i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + +int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, + int32_t max_exclusive) { CHECK(min_inclusive < max_exclusive) << "ValueError: max_exclusive must be greater than min_inclusive."; if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } support::LinearCongruentialEngine rand_(rand_state); - std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); return dist(rand_); } +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) { + if (k == 1) { + return {SampleInt(rand_state, 0, n)}; + } + if (k == 2) { + int32_t result0 = SampleInt(rand_state, 0, n); + int32_t result1 = SampleInt(rand_state, 0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int32_t i = 0; i < n; ++i) { + order[i] = i; + } + for (int32_t i = 0; i < k; ++i) { + int32_t j = SampleInt(rand_state, i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; - int i = -1; - int n = candidates.size(); - + int32_t i = -1; + int32_t n = candidates.size(); if (decision->defined()) { const auto* int_imm = decision->as(); i = int_imm->value; @@ -51,7 +177,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } else { std::vector weights = support::AsVector(probs); - std::discrete_distribution dist(weights.begin(), weights.end()); + std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n @@ -62,6 +188,151 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // 
(primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int32_t result0 = 1; + int32_t result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + int32_t x1 = SampleInt(rand_state, 0, p + 1); + int32_t x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(rand_state, 0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + if (p == 1) { + result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. We have to sample uniformly from the solution of: + // x_1 + x_2 + ... + x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = + SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int32_t i = 0, last = -1; i < n_splits; ++i) { + int32_t x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; +} + +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits, + int32_t max_innermost_factor) { + if (max_innermost_factor == -1) { + return SamplePerfectTile(rand_state, extent, n_splits); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int32_t i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. 
On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); + result.push_back(innermost); + return result; +} + +std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + Optional>* decision) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + int64_t extent = GetLoopIntExtent(loop); + std::vector result; + if (extent == -1) { + // Case 1. Handle loops with non-constant length + result = std::vector(n_splits, 1); + result[0] = -1; + } else if (decision->defined()) { + // Case 2. Use previous decision + result = support::AsVector(decision->value()); + int n = result.size(); + ICHECK_GE(n, 2); + int64_t len = extent; + for (int i = n - 1; i > 0; --i) { + int64_t& l = result[i]; + // A previous decision could become invalid because of the change of outer tiles + // To handle this case properly, we check if the tiling strategy is still perfect. + // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles + if (len % l != 0) { + l = len; + } + len /= l; + } + result[0] = len; + } else { + // Case 3. Use fresh new sampling result + result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + ICHECK_LE(result.back(), max_innermost_factor); + } + *decision = support::AsArray(result); + return result; +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -96,7 +367,38 @@ struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SamplePerfectTile"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer max_innermost_factor, + Optional> decision) { + return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, + Integer max_innermost_factor, Optional> decision) { + PythonAPICall py("sample_perfect_tile"); + py.Input("loop", loop_rv); + py.Input("n", n->value); + py.Input("max_innermost_factor", max_innermost_factor->value); + py.Decision(decision); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); +TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 84a37c392e..8f7caa9145 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,11 +123,25 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") /******** (FFI) Sampling ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") + .set_body_method(&ScheduleNode::SamplePerfectTile); /******** (FFI) Get blocks & loops ********/ 
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") + .set_body_typed([](Schedule self, ObjectRef rv) { + if (const auto* block_rv = rv.as()) { + return self->GetChildBlocks(GetRef(block_rv)); + } + if (const auto* loop_rv = rv.as()) { + return self->GetChildBlocks(GetRef(loop_rv)); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + }); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); @@ -145,6 +159,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index cc48f2b9e7..94f15d5c65 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -43,6 +43,7 @@ Schedule TracedScheduleNode::Copy() const { } /******** Schedule: Sampling ********/ + ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { @@ -57,6 +58,21 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, return result; } +Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + Array results = CreateRV(tir::SamplePerfectTile( + &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{Integer(n), Integer(max_innermost_factor)}, + /*outputs=*/{results.begin(), results.end()}), + /*decision=*/decision); + return results; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -81,6 +97,28 @@ Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } +Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + 
/******** Schedule: Transform loops ********/ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs) { @@ -190,6 +228,31 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index fae5ca8608..d5676f4cdc 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,20 +47,15 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. 
- * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) final; - + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; + Array GetChildBlocks(const BlockRV& block_rv) final; + Array GetChildBlocks(const LoopRV& loop_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; @@ -75,6 +70,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index a63a9f0796..c66c2ca766 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -53,8 +53,8 @@ namespace tir { * \brief A helper macro to convert an sref to the statement it points to, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param SRef The SRef to be casted - * \param Type The type to be casted to, can be Block or For + * \param SRef The SRef to be cast + * \param Type The type to be cast to, can be Block or For */ #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ SRef->StmtAs(); \ @@ -64,7 +64,7 @@ namespace tir { * \brief A helper macro to convert an sref to the block it points to, * throwing an internal error if downcasting fails * \param Result The result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_BLOCK(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ @@ -75,7 +75,7 @@ namespace tir { * \brief A helper macro to convert an sref to the for-loop it points to, * throwing an internal error if downcasting fails * \param Result The name of the result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_FOR(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ @@ -86,8 +86,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS_OR_ERR(Result, From, Type) \ From.as(); \ @@ -97,8 +97,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * throwing an internal error if downcast fails. 
* \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS(Result, From, Type) \ TVM_TYPE_AS_OR_ERR(Result, From, Type) \ @@ -129,8 +129,8 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { * \param thread_scope The thread scope to be relaxed * \return A boolean indicating the result */ -inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_scope, - const runtime::ThreadScope& thread_scope) { +inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scope, + const runtime::ThreadScope& thread_scope) { if (storage_scope.rank == runtime::StorageRank::kWarp) { // for warp memory, we only relax threadIdx.x return thread_scope.rank == 1 && thread_scope.dim_index == 0; @@ -210,6 +210,28 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } +/**************** Loop extents ****************/ + +/*! + * \brief Get the extents of a loop + * \param loop The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const ForNode* loop) { + const auto* int_extent = loop->extent.as(); + return int_extent ? int_extent->value : -1; +} + +/*! + * \brief Get the extents of a loop + * \param loop_sref The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + return GetLoopIntExtent(loop); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a1f488f386..36f0a3488c 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -232,7 +232,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { const String& thread_tag = loop->thread_binding.value()->thread_tag; // When there is warp memory // threadIdx.x must be set to be warp index. - return CanRelaxStorageUndereThread(scope, runtime::ThreadScope::Create(thread_tag)); + return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create(thread_tag)); } /**************** Class members ****************/ diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py new file mode 100644 index 0000000000..e7217ea000 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.meta_schedule.runner.runner import Runner +from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T + +from tvm.meta_schedule.measure_callback import PyMeasureCallback +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.utils import _get_hex_address + +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_measure_callback(): + class FancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + assert len(measure_candidates) == 1 + assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + assert ( + len(builds) == 1 + and builds[0].error_msg is None + and builds[0].artifact_path == "test_build" + ) + assert ( + len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + ) + return True + + measure_callback = FancyMeasureCallback() + assert measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_fail(): + class FailingMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return False + + measure_callback = FailingMeasureCallback() + assert not measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(None, None)], + [BuilderResult(None, None)], + [RunnerResult(None, None)], + ) + + +def test_meta_schedule_measure_callback_as_string(): + class NotSoFancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" + + measure_callback = NotSoFancyMeasureCallback() + pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)") + assert pattern.match(str(measure_callback)) + + +if __name__ == "__main__": + test_meta_schedule_measure_callback() + 
test_meta_schedule_measure_callback_fail() + test_meta_schedule_measure_callback_as_string() diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py new file mode 100644 index 0000000000..b4d94dc9a8 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List, Optional + +import re + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T + +from tvm.meta_schedule.mutator import PyMutator +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import Schedule, Trace + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_mutator(): + class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + mutator = FancyMutator() + sch = Schedule(Matmul) + res = mutator.apply(sch.trace) + assert res is not None + new_sch = sch.copy() + res.apply_to_schedule(new_sch, remove_postproc=True) + assert_structural_equal(sch.mod, new_sch.mod) + + +def test_meta_schedule_mutator_as_string(): + class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + + def __str__(self) -> str: + return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" + + mutator = YetAnotherFancyMutator() + pattern = re.compile(r"YetAnotherFancyMutator\(0x[a-f|0-9]*\)") + assert pattern.match(str(mutator)) + + +if __name__ == "__main__": + test_meta_schedule_mutator() + test_meta_schedule_mutator_as_string() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py new file mode 100644 index 0000000000..95b5ed002b 
--- /dev/null +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List +import pytest +import math +import sys + +import tvm +from tvm._ffi.base import TVMError, py2cerror +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T +from tvm.tir.schedule import Schedule, BlockRV, block_scope +from tvm.target import Target + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import trace +from tvm.tir.schedule.trace import Trace + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DuplicateMatmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class TrinityMatmul: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.alloc_buffer((1024, 1024), "float32") + D = T.match_buffer(d, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 + for i, j in T.grid(1024, 1024): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, 
j]) + D[vi, vj] = C[vi, vj] * 5.0 + + +@tvm.script.ir_module +class TrinityMatmulProcessedForReference: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024], dtype="float32") + D = T.match_buffer(d, [1024, 1024], dtype="float32") + # body + # with tir.block("root") + B = T.alloc_buffer([1024, 1024], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("A"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("C"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([B[vi, vj]]) + T.writes([D[vi, vj]]) + D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +class WowSoFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [new_sch] + + +class DoubleScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result = [new_sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result.append(new_sch) + return result + + +class ReorderScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, i_0, j_0) + result = [new_sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3) + result.append(new_sch) + return result + + +def test_meta_schedule_post_order_apply(): + mod = Matmul + context = TuneContext( + mod=mod, target=Target("llvm"), task_name="Test Task", sch_rules=[WowSoFancyScheduleRule()] + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs 
= post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + try: + tvm.ir.assert_structural_equal(mod, schs[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(schs[0]) + + +def test_meta_schedule_post_order_apply_double(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 2 + for sch in schs: + try: + tvm.ir.assert_structural_equal(mod, sch.mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_multiple(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + try: + tvm.ir.assert_structural_equal(mod, sch.mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_duplicate_matmul(): + mod = DuplicateMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Duplicate Matmul Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises( + TVMError, + match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)" + r" is false: Duplicated block name matmul in function main not supported!", + ): + post_order_apply.generate_design_space(mod) + + +def test_meta_schedule_post_order_apply_remove_block(): + class TrinityDouble(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + class RemoveBlock(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + sch = sch.copy() + if sch.get(block).name_hint == "B": + sch.compute_inline(block) + return [sch] + + def correct_trace(a, b, c, d): + return "\n".join( + [ + 'b0 = sch.get_block(name="A", func_name="main")', + 'b1 = sch.get_block(name="B", func_name="main")', + 'b2 = sch.get_block(name="C", func_name="main")', + "sch.compute_inline(block=b1)", + 'b3 = sch.get_block(name="A", func_name="main")', + 'b4 = sch.get_block(name="C", func_name="main")', + "l5, l6 = sch.get_loops(block=b4)", + "l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")", + "l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")", + "sch.reorder(l7, l9, l8, l10)", + "l11, l12 = 
sch.get_loops(block=b3)", + "l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")", + "l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")", + "sch.reorder(l13, l15, l14, l16)", + ] + ) + + mod = TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Remove Block Task", + sch_rules=[RemoveBlock(), TrinityDouble()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + with pytest.raises( + tvm.tir.schedule.schedule.ScheduleError, + match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", + ): + sch.get_block("B", "main") + assert ( + str(sch.trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16]) + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py new file mode 100644 index 0000000000..52f07fdff0 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import math
+import re
+
+import tvm
+from tvm.script import tir as T
+
+from tvm.meta_schedule.postproc import PyPostproc
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.utils import _get_hex_address
+
+from tvm.tir.schedule import Schedule
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
+# fmt: off
+
+@tvm.script.ir_module
+class Matmul:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+
+
+def _check_correct(schedule: Schedule):
+    trace = schedule.trace
+    for inst in trace.decisions:
+        assert math.prod(trace.decisions[inst]) == 1024
+
+
+def schedule_matmul(sch: Schedule):
+    block = sch.get_block("matmul")
+    i, j, k = sch.get_loops(block=block)
+    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
+    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
+    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
+
+
+def test_meta_schedule_postproc():
+    class FancyPostproc(PyPostproc):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, sch: Schedule) -> bool:
+            schedule_matmul(sch)
+            return True
+
+    postproc = FancyPostproc()
+    mod = Matmul
+    sch = Schedule(mod)
+    assert postproc.apply(sch)
+    try:
+        tvm.ir.assert_structural_equal(sch.mod, mod)
+        raise Exception("The post processing did not change the schedule.")
+    except ValueError:
+        _check_correct(sch)
+
+
+def test_meta_schedule_postproc_fail():
+    class FailingPostproc(PyPostproc):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, sch: Schedule) -> bool:
+            return False
+
+    postproc = FailingPostproc()
+    sch = Schedule(Matmul)
+    assert not postproc.apply(sch)
+
+
+def test_meta_schedule_postproc_as_string():
+    class NotSoFancyPostproc(PyPostproc):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, sch: Schedule) -> bool:
+            pass
+
+        def __str__(self) -> str:
+            return f"NotSoFancyPostproc({_get_hex_address(self.handle)})"
+
+    postproc = NotSoFancyPostproc()
+    pattern = re.compile(r"NotSoFancyPostproc\(0x[0-9a-f]*\)")
+    assert pattern.match(str(postproc))
+
+
+if __name__ == "__main__":
+    test_meta_schedule_postproc()
+    test_meta_schedule_postproc_fail()
+    test_meta_schedule_postproc_as_string()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py
new file mode 100644
index 0000000000..e79ca69ca6
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+from typing import List
+
+import math
+import re
+
+import tvm
+from tvm.script import tir as T
+
+from tvm.meta_schedule.schedule_rule import PyScheduleRule
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.utils import _get_hex_address
+
+from tvm.tir.schedule import Schedule, BlockRV
+
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
+# fmt: off
+
+@tvm.script.ir_module
+class Matmul:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+
+
+def _check_correct(schedule: Schedule):
+    trace = schedule.trace
+    for inst in trace.decisions:
+        assert math.prod(trace.decisions[inst]) == 1024
+
+
+def test_meta_schedule_schedule_rule():
+    class FancyScheduleRule(PyScheduleRule):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
+            i, j, k = sch.get_loops(block=block)
+            i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
+            j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
+            k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+            sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
+            return [sch]
+
+    sch_rule = FancyScheduleRule()
+    mod = Matmul
+    sch = Schedule(mod)
+    res = sch_rule.apply(sch, block=sch.get_block("matmul"))
+    assert len(res) == 1
+    try:
+        tvm.ir.assert_structural_equal(mod, res[0].mod)
+        raise Exception("The schedule rule did not change the schedule.")
+    except ValueError:
+        _check_correct(res[0])
+
+
+def test_meta_schedule_schedule_rule_as_string():
+    class YetStillSomeFancyScheduleRule(PyScheduleRule):
+        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+            pass
+
+        def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]:
+            pass
+
+        def __str__(self) -> str:
+            return f"YetStillSomeFancyScheduleRule({_get_hex_address(self.handle)})"
+
+    sch_rule = YetStillSomeFancyScheduleRule()
+    pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[0-9a-f]*\)")
+    assert pattern.match(str(sch_rule))
+
+
+if __name__ == "__main__":
+    test_meta_schedule_schedule_rule()
+    test_meta_schedule_schedule_rule_as_string()
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 9b3ddfd7c7..f940d11b79 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -26,7 +26,7 @@
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.runner import RunnerResult
 from tvm.meta_schedule.space_generator import ScheduleFn
-from tvm.meta_schedule.search_strategy import ReplayTrace
+from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace, ReplayFunc
 from tvm.script import tir as T
 from tvm.tir.schedule import Schedule, Trace
 
@@ -56,9 +56,13 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 
 
-def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool:
-    trace_1 = Trace(sch_1.trace.insts, {})
-    trace_2 = Trace(sch_2.trace.insts, {})
+def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool:
+    if remove_decisions:
+        trace_1 = Trace(sch_1.trace.insts, {})
+        trace_2 = Trace(sch_2.trace.insts, {})
+    else:
+        trace_1 = sch_1.trace
+        trace_2 = sch_2.trace
     return str(trace_1) == str(trace_2)
 
 
@@ -72,29 +76,35 @@ def _schedule_matmul(sch: Schedule):
     sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
 
 
-def test_meta_schedule_replay_trace():
+@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace])
+def test_meta_schedule_replay_func(TestClass: SearchStrategy):
     num_trials_per_iter = 7
     num_trials_total = 20
-    (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
-    replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
-    tune_context = TuneContext(mod=Matmul)
-    replay.initialize_with_tune_context(tune_context)
-
-    num_trials_each_round: List[int] = []
-    replay.pre_tuning([example_sch])
-    while True:
-        candidates = replay.generate_measure_candidates()
-        if candidates is None:
-            break
-        num_trials_each_round.append(len(candidates))
+    strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
+    tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+    tune_context.space_generator.initialize_with_tune_context(tune_context)
+    spaces = tune_context.space_generator.generate_design_space(tune_context.mod)
+
+    strategy.initialize_with_tune_context(tune_context)
+    strategy.pre_tuning(spaces)
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+    num_trials_each_iter: List[int] = []
+    candidates = strategy.generate_measure_candidates()
+    while candidates is not None:
+        num_trials_each_iter.append(len(candidates))
         runner_results: List[RunnerResult] = []
         for candidate in candidates:
-            assert _is_trace_equal(candidate.sch, example_sch)
-            runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
-        replay.notify_runner_results(runner_results)
-    replay.post_tuning()
-    assert num_trials_each_round == [7, 7, 6]
+            assert _is_trace_equal(
+                candidate.sch,
+                correct_sch,
+                remove_decisions=isinstance(strategy, ReplayTrace),
+            )
+            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
+        strategy.notify_runner_results(runner_results)
+        candidates = strategy.generate_measure_candidates()
+    strategy.post_tuning()
+    assert num_trials_each_iter == [7, 7, 6]
 
 
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py
index 3f7749ca9e..8a674411fd 100644
--- a/tests/python/unittest/test_meta_schedule_space_generator.py
+++ b/tests/python/unittest/test_meta_schedule_space_generator.py
@@ -23,6 +23,9 @@
 import pytest
 
 import tvm
+from tvm._ffi.base import TVMError
+from tvm.ir.module import IRModule
+from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator
 from tvm.script import tir as T
 from tvm.tir.schedule import Schedule
 
@@ -86,5 +89,10 @@ def test_meta_schedule_design_space_generator_union():
     _check_correct(design_space)
 
 
+def test_meta_schedule_design_space_generator_NIE():
+    with pytest.raises(NotImplementedError):
+        PySpaceGenerator()
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py
new file mode 100644
index 0000000000..e10a6dcc35
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_task_extraction.py
@@ -0,0 +1,84 @@
+import sys
+from typing import Dict, List, Tuple
+
+import pytest
+import tvm
+from tvm import meta_schedule as ms
+from tvm.ir import IRModule
+from tvm.meta_schedule.integration import ExtractedTask
+from tvm.meta_schedule.testing import get_torch_model, MODEL_TYPE, MODEL_TYPES
+from tvm.runtime import NDArray
+
+
+@pytest.mark.skip("Skip because it runs too slowly as a unittest")
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        # Image classification
+        "resnet50",
+        "alexnet",
+        "vgg16",
+        "squeezenet1_0",
+        "densenet121",
+        "densenet161",
+        "densenet169",
+        "densenet201",
+        "inception_v3",
+        "googlenet",
+        "shufflenet_v2_x1_0",
+        "mobilenet_v2",
+        "mobilenet_v3_large",
+        "mobilenet_v3_small",
+        "resnext50_32x4d",
+        "wide_resnet50_2",
+        "mnasnet1_0",
+        # Segmentation
+        "fcn_resnet50",
+        "fcn_resnet101",
+        "deeplabv3_resnet50",
+        "deeplabv3_resnet101",
+        "deeplabv3_mobilenet_v3_large",
+        "lraspp_mobilenet_v3_large",
+        # Object detection
+        "fasterrcnn_resnet50_fpn",
+        "fasterrcnn_mobilenet_v3_large_fpn",
+        "fasterrcnn_mobilenet_v3_large_320_fpn",
+        "maskrcnn_resnet50_fpn",
+        # Video classification
+        "r3d_18",
+        "mc3_18",
+        "r2plus1d_18",
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 8, 16])
+@pytest.mark.parametrize("target", ["llvm", "cuda"])
+def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int, target: str):
+    if model_name == "inception_v3" and batch_size == 1:
+        pytest.skip("inception_v3 does not handle batch_size of 1")
+
+    input_shape: Tuple[int, ...]
+    if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+        input_shape = (batch_size, 3, 299, 299)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
+        input_shape = (batch_size, 3, 299, 299)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
+        input_shape = (1, 3, 300, 300)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
+        input_shape = (batch_size, 3, 3, 299, 299)
+    else:
+        raise ValueError("Unsupported model: " + model_name)
+
+    output_shape: Tuple[int, int] = (batch_size, 1000)
+
+    mod, params = get_torch_model(
+        model_name=model_name,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype="float32",
+    )
+
+    extracted_tasks = ms.integration.extract_task(mod, params, target=target)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 4854aeb5f5..8a8f16c389 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -24,17 +24,18 @@
 import pytest
 
 import tvm
+from tvm._ffi.base import TVMError
 from tvm.script import tir as T
 from tvm.ir import IRModule
-from tvm.tir import Schedule
+from tvm.tir import Schedule, schedule
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.space_generator import ScheduleFn
 from tvm.meta_schedule.search_strategy import ReplayTrace
 from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
 from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult
 from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.task_scheduler import RoundRobin
-from tvm.meta_schedule.utils import structural_hash
+from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler
+from tvm.tir.expr import Not
 
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
 
@@ -224,5 +225,77 @@ def test_meta_schedule_task_scheduler_multiple():
     assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
 
 
+def test_meta_schedule_task_scheduler_NIE():
+    class MyTaskScheduler(PyTaskScheduler):
+        pass
+
+    with pytest.raises(NotImplementedError):
+        MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase())
+
+
+def test_meta_schedule_task_scheduler_override_next_task_id_only():
+    class MyTaskScheduler(PyTaskScheduler):
+        done = set()
+
+        def next_task_id(self) -> int:
+            while len(self.done) != len(tasks):
+                x = random.randint(0, len(tasks) - 1)
+                task = tasks[x]
+                if not task.is_stopped:
+                    """Calling the base function via the following route:
+                    Python side:
+                        PyTaskScheduler does not have `_is_task_running`
+                        Call TaskScheduler's `is_task_running`, which calls into the FFI
+                    C++ side:
+                        The FFI calls TaskScheduler's `is_task_running`
+                        But it is overridden in PyTaskScheduler
+                        PyTaskScheduler checks if the function is overridden in Python
+                        If not, it falls back to TaskScheduler's vtable, calling
+                            TaskScheduler::IsTaskRunning
+                    """
+                    if self._is_task_running(x):
+                        # Same routing as above
+                        self._join_running_task(x)
+                    return x
+                else:
+                    self.done.add(x)
+            return -1
+
+    num_trials_per_iter = 6
+    num_trials_total = 101
+    tasks = [
+        TuneContext(
+            MatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="Matmul",
+            rand_state=42,
+        ),
+        TuneContext(
+            MatmulReluModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="MatmulRelu",
+            rand_state=0xDEADBEEF,
+        ),
+        TuneContext(
+            BatchMatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
+            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
+            task_name="BatchMatmul",
+            rand_state=0x114514,
+        ),
+    ]
+    database = DummyDatabase()
+    scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database)
+    scheduler.tune()
+    assert len(database) == num_trials_total * len(tasks)
+    for task in tasks:
+        assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py
new file mode 100644
index 0000000000..79a7aad10f
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_read_write_at.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
+
+@T.prim_func
+def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
+    A = T.match_buffer(a, [2048, 2048], "float32")
+    B = T.match_buffer(b, [2048, 2048], "float32")
+    C = T.match_buffer(c, [2048, 2048], "float32")
+    for by in T.thread_binding(0, 32, thread = "blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread = "blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread = "vthread.y"):
+                for vx in T.thread_binding(0, 2, thread = "vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread = "threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread = "threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                for k1 in T.unroll(0, 8):
+                                    for _, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A[vi, vk], B[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = 0.0
+                                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@T.prim_func
+def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = T.float32(0)
+                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj]
+
+
+@T.prim_func
+def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                with T.block("B_shared"):
+                                    v0 = T.axis.S(256, k0)
+                                    v1 = T.axis.S(32, bx)
+                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(8, 64):
+                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = T.float32(0)
+                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
+
+@T.prim_func
+def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                with T.block("B_shared"):
+                                    v0 = T.axis.S(256, k0)
+                                    v1 = T.axis.S(32, bx)
+                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(8, 64):
+                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
+                                            T.writes([C_shared[vi, vj]])
+                                            with T.init():
+                                                C_shared[vi, vj] = T.float32(0)
+                                            C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
+                            with T.block("C_shared"):
+                                v0 = T.axis.S(32, by)
+                                v1 = T.axis.S(32, bx)
+                                T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
+                                T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
+                                T.block_attr({"auto_copy":1})
+                                for ax0, ax1 in T.grid(64, 64):
+                                    C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
+# fmt: on
+
+
+def test_read_at_global_to_shared_a():
+    sch = tir.Schedule(cuda_matmul, debug_mask="all")
+    block = sch.get_block("C")
+    # pylint: disable=invalid-name
+    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block)
+    # pylint: enable=invalid-name
+    sch.read_at(k0, block, 1, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a)
+    verify_trace_roundtrip(sch, cuda_matmul)
+
+
+def test_read_at_global_to_shared_ab():
+    sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all")
+    block = sch.get_block("C")
+    # pylint: disable=invalid-name
+    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block)
+    # pylint: enable=invalid-name
+    sch.read_at(k0, block, 2, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab)
+    verify_trace_roundtrip(sch, cuda_matmul_read_at_a)
+
+
+def test_write_at_local_to_shared_c():
+    sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all")
+    block = sch.get_block("C")
+    # pylint: disable=invalid-name
+    _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block)
+    # pylint: enable=invalid-name
+    sch.write_at(tx, block, 0, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c)
+    verify_trace_roundtrip(sch, cuda_matmul_read_at_ab)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py
index fbf0a6a5bd..5d2676e41d 100644
--- a/tests/python/unittest/test_tir_schedule_sampling.py
+++ b/tests/python/unittest/test_tir_schedule_sampling.py
@@ -14,15 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import sys
 from collections import defaultdict
+import sys
 
 import pytest
-import tvm
+
 from tvm import tir
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
-from tvm.tir.schedule import Trace
 
 # pylint: disable=no-member,invalid-name,unused-variable
 
@@ -30,9 +29,9 @@
 
 @T.prim_func
 def elementwise(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128, 128))
-    B = T.match_buffer(b, (128, 128, 128))
-    for i, j, k in T.grid(128, 128, 128):
+    A = T.match_buffer(a, (128, 257, 1470))
+    B = T.match_buffer(b, (128, 257, 1470))
+    for i, j, k in T.grid(128, 257, 1470):
         with T.block("B"):
             vi, vj, vk = T.axis.remap("SSS", [i, j, k])
             B[vi, vj, vk] = A[vi, vj, vk] * 2.0
@@ -42,7 +41,7 @@ def elementwise(a: T.handle, b: T.handle) -> None:
 
 
 def test_sample_categorical():
-    """Test sample categprical sampling function"""
+    """Test sample categorical sampling function"""
     n = 1000
     sch = tir.Schedule(elementwise, seed=42, debug_mask="all")
     counter = defaultdict(int)
@@ -87,5 +86,35 @@ def test_sample_categorical_serialize():
         assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value]
 
 
+def test_sample_perfect_tile_power_of_two():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    i, _, _ = sch.get_loops(sch.get_block("B"))
+    factors = sch.sample_perfect_tile(i, n=4)
+    factors = [sch.get(i) for i in factors]
+    prod = factors[0] * factors[1] * factors[2] * factors[3]
+    assert prod == 128
+    verify_trace_roundtrip(sch, mod=elementwise)
+
+
+def test_sample_perfect_tile_prime():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    _, i, _ = sch.get_loops(sch.get_block("B"))
+    factors = sch.sample_perfect_tile(i, n=4)
+    factors = [sch.get(i) for i in factors]
+    prod = factors[0] * factors[1] * factors[2] * factors[3]
+    assert prod == 257
+    verify_trace_roundtrip(sch, mod=elementwise)
+
+
+def test_sample_perfect_tile_composite():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    _, _, i = sch.get_loops(sch.get_block("B"))
+    factors = sch.sample_perfect_tile(i, n=4)
+    factors = [sch.get(i) for i in factors]
+    prod = factors[0] * factors[1] * factors[2] * factors[3]
+    assert prod == 1470
+    verify_trace_roundtrip(sch, mod=elementwise)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index 440d0ab67a..1596d08a1f 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -142,5 +142,22 @@ def test_tir_schedule_remove_rv():
         sch.get(block_rv)
 
 
+def test_get_child_blocks():
+    s = tir.Schedule(matmul, debug_mask="all")
+    init = s.get_block("init")
+    update = s.get_block("update")
+    # loop
+    blocks = s.get_child_blocks(s.get_loops(init)[0])
+    assert len(blocks) == 2
+    assert s.get(init) == s.get(blocks[0])
+    assert s.get(update) == s.get(blocks[1])
+    # block
+    root = s.get_block("root")
+    blocks = s.get_child_blocks(root)
+    assert len(blocks) == 2
+    assert s.get(init) == s.get(blocks[0])
+    assert s.get(update) == s.get(blocks[1])
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
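Usage sketch (not part of the patch above): the sample_perfect_tile tests only check that the sampled factors multiply back to the loop extent, but in practice the sampled ExprRVs are fed straight into split. The sketch below targets the elementwise workload from test_tir_schedule_sampling.py and uses only APIs exercised by these tests plus Schedule.split; the helper name tile_first_loop is ours, purely for illustration.

    from tvm import tir
    from tvm.tir.schedule.testing import verify_trace_roundtrip

    def tile_first_loop(mod):
        # Sample a 4-way perfect tiling of loop `i` and apply it via split.
        sch = tir.Schedule(mod, debug_mask="all")
        i, _, _ = sch.get_loops(sch.get_block("B"))
        factors = sch.sample_perfect_tile(i, n=4)  # ExprRVs whose product equals the loop extent
        sch.split(loop=i, factors=factors)         # reuse the sampled decisions directly
        verify_trace_roundtrip(sch, mod=mod)       # the sampled decisions survive trace round-tripping
        return sch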