From 8c96e90e34bf33d6ff8ff70555916532b9f00f12 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 8 Nov 2021 16:28:29 -0800 Subject: [PATCH 1/9] Fix sttr func & schedule naming. --- include/tvm/meta_schedule/postproc.h | 8 ++++---- include/tvm/meta_schedule/schedule_rule.h | 10 +++++----- python/tvm/meta_schedule/postproc/postproc.py | 10 +++++----- .../tvm/meta_schedule/schedule_rule/schedule_rule.py | 6 +++--- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 058a2afe27..a4baf4d36d 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -48,7 +48,7 @@ class PostprocNode : public runtime::Object { /*! * \brief Apply a post processing to the given schedule. - * \param sch The schedule to be post processed. + * \param schedule The schedule to be post processed. * \return Whether the post processing was successfully applied. */ virtual bool Apply(const tir::Schedule& sch) = 0; @@ -67,7 +67,7 @@ class PyPostprocNode : public PostprocNode { using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! * \brief Apply a post processing to the given schedule. - * \param sch The schedule to be post processed. + * \param schedule The schedule to be post processed. * \return Whether the post processing was successfully applied. */ using FApply = runtime::TypedPackedFunc; @@ -96,9 +96,9 @@ class PyPostprocNode : public PostprocNode { this->f_initialize_with_tune_context(context); } - bool Apply(const tir::Schedule& sch) final { + bool Apply(const tir::Schedule& schedule) final { ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; - return this->f_apply(sch); + return this->f_apply(schedule); } static constexpr const char* _type_key = "meta_schedule.PyPostproc"; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index b9e6d87774..450bcd8f95 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -43,11 +43,11 @@ class ScheduleRuleNode : public runtime::Object { /*! * \brief Apply a schedule rule to the specific block in the given schedule. - * \param sch The schedule to be modified. + * \param schedule The schedule to be modified. * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual runtime::Array Apply(const tir::Schedule& sch, + virtual runtime::Array Apply(const tir::Schedule& schedule, const tir::BlockRV& block) = 0; static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; @@ -64,7 +64,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode { using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! * \brief The function type of `Apply` method. - * \param sch The schedule to be modified. + * \param schedule The schedule to be modified. * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ @@ -95,9 +95,9 @@ class PyScheduleRuleNode : public ScheduleRuleNode { this->f_initialize_with_tune_context(context); } - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + Array Apply(const tir::Schedule& schedule, const tir::BlockRV& block) final { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; - return this->f_apply(sch, block); + return this->f_apply(schedule, block); } static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index e05cc9a527..38596359ce 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -52,12 +52,12 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: self, tune_context ) - def apply(self, sch: Schedule) -> bool: + def apply(self, schedule: Schedule) -> bool: """Apply a post processing to the given schedule. Parameters ---------- - sch : Schedule + schedule : Schedule The schedule to be post processed. Returns @@ -65,7 +65,7 @@ def apply(self, sch: Schedule) -> bool: result : bool Whether the post processing was successfully applied. """ - return _ffi_api.PostprocApply(self, sch) + return _ffi_api.PostprocApply(self, schedule) @register_object("meta_schedule.PyPostproc") @@ -80,8 +80,8 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) @check_override(self.__class__, Postproc) - def f_apply(sch: Schedule) -> bool: - return self.apply(sch) + def f_apply(schedule: Schedule) -> bool: + return self.apply(schedule) def f_as_string() -> str: return str(self) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index b995e5acb6..7994502dbf 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -52,7 +52,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: Parameters ---------- - sch : Schedule + schedule : Schedule The schedule to be modified. block : BlockRV The specific block to apply the schedule rule. @@ -79,8 +79,8 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) @check_override(self.__class__, ScheduleRule) - def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: - return self.apply(sch, block) + def f_apply(schedule: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(schedule, block) def f_as_string() -> str: return self.__str__() From 61b65eeb5022e347321908ea2e8d9838688b5323 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 8 Nov 2021 21:56:37 -0800 Subject: [PATCH 2/9] Fix schedule -> sch. --- include/tvm/meta_schedule/postproc.h | 8 ++++---- include/tvm/meta_schedule/schedule_rule.h | 10 +++++----- python/tvm/meta_schedule/postproc/postproc.py | 10 +++++----- .../tvm/meta_schedule/schedule_rule/schedule_rule.py | 6 +++--- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index a4baf4d36d..058a2afe27 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -48,7 +48,7 @@ class PostprocNode : public runtime::Object { /*! * \brief Apply a post processing to the given schedule. - * \param schedule The schedule to be post processed. + * \param sch The schedule to be post processed. * \return Whether the post processing was successfully applied. */ virtual bool Apply(const tir::Schedule& sch) = 0; @@ -67,7 +67,7 @@ class PyPostprocNode : public PostprocNode { using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! * \brief Apply a post processing to the given schedule. - * \param schedule The schedule to be post processed. + * \param sch The schedule to be post processed. * \return Whether the post processing was successfully applied. */ using FApply = runtime::TypedPackedFunc; @@ -96,9 +96,9 @@ class PyPostprocNode : public PostprocNode { this->f_initialize_with_tune_context(context); } - bool Apply(const tir::Schedule& schedule) final { + bool Apply(const tir::Schedule& sch) final { ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; - return this->f_apply(schedule); + return this->f_apply(sch); } static constexpr const char* _type_key = "meta_schedule.PyPostproc"; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 450bcd8f95..b9e6d87774 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -43,11 +43,11 @@ class ScheduleRuleNode : public runtime::Object { /*! * \brief Apply a schedule rule to the specific block in the given schedule. - * \param schedule The schedule to be modified. + * \param sch The schedule to be modified. * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual runtime::Array Apply(const tir::Schedule& schedule, + virtual runtime::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; @@ -64,7 +64,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode { using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! * \brief The function type of `Apply` method. - * \param schedule The schedule to be modified. + * \param sch The schedule to be modified. * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ @@ -95,9 +95,9 @@ class PyScheduleRuleNode : public ScheduleRuleNode { this->f_initialize_with_tune_context(context); } - Array Apply(const tir::Schedule& schedule, const tir::BlockRV& block) final { + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; - return this->f_apply(schedule, block); + return this->f_apply(sch, block); } static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 38596359ce..e05cc9a527 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -52,12 +52,12 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: self, tune_context ) - def apply(self, schedule: Schedule) -> bool: + def apply(self, sch: Schedule) -> bool: """Apply a post processing to the given schedule. Parameters ---------- - schedule : Schedule + sch : Schedule The schedule to be post processed. Returns @@ -65,7 +65,7 @@ def apply(self, schedule: Schedule) -> bool: result : bool Whether the post processing was successfully applied. """ - return _ffi_api.PostprocApply(self, schedule) + return _ffi_api.PostprocApply(self, sch) @register_object("meta_schedule.PyPostproc") @@ -80,8 +80,8 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) @check_override(self.__class__, Postproc) - def f_apply(schedule: Schedule) -> bool: - return self.apply(schedule) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) def f_as_string() -> str: return str(self) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 7994502dbf..b995e5acb6 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -52,7 +52,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: Parameters ---------- - schedule : Schedule + sch : Schedule The schedule to be modified. block : BlockRV The specific block to apply the schedule rule. @@ -79,8 +79,8 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) @check_override(self.__class__, ScheduleRule) - def f_apply(schedule: Schedule, block: BlockRV) -> List[Schedule]: - return self.apply(schedule, block) + def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(sch, block) def f_as_string() -> str: return self.__str__() From a3d9440e6cda6c344730d0b9f58383268deb7377 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 9 Nov 2021 15:11:00 -0800 Subject: [PATCH 3/9] Add feature extractor. --- include/tvm/meta_schedule/feature_extractor.h | 110 ++++++++++++++++++ include/tvm/meta_schedule/mutator.h | 4 +- include/tvm/meta_schedule/postproc.h | 4 +- include/tvm/meta_schedule/schedule_rule.h | 4 +- .../feature_extractor/__init__.py | 21 ++++ .../feature_extractor/feature_extractor.py | 78 +++++++++++++ python/tvm/meta_schedule/postproc/__init__.py | 5 +- .../feature_extractor/feature_extractor.cc | 51 ++++++++ src/meta_schedule/utils.h | 1 + .../test_meta_schedule_feature_extractor.py | 61 ++++++++++ 10 files changed, 333 insertions(+), 6 deletions(-) create mode 100644 include/tvm/meta_schedule/feature_extractor.h create mode 100644 python/tvm/meta_schedule/feature_extractor/__init__.py create mode 100644 python/tvm/meta_schedule/feature_extractor/feature_extractor.py create mode 100644 src/meta_schedule/feature_extractor/feature_extractor.cc create mode 100644 tests/python/unittest/test_meta_schedule_feature_extractor.py diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h new file mode 100644 index 0000000000..2695ae495d --- /dev/null +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Extractor for features from measure candidates for use in cost model. */ +class FeatureExtractorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~FeatureExtractorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + virtual Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; + TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); +}; + +/*! \brief The feature extractor with customized methods on the python-side. */ +class PyFeatureExtractorNode : public FeatureExtractorNode { + public: + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + using FExtractFrom = runtime::TypedPackedFunc( + const TuneContext& tune_context, const Array& candidates)>; + /*! + * \brief Get the feature extractor as string with name. + * \return The string of the feature extractor. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `ExtractFrom` funcion. */ + FExtractFrom f_extract_from; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_extract_from` is not visited + // `f_as_string` is not visited + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + return f_extract_from(tune_context, candidates); + } + + static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); +}; + +/*! + * \brief Managed reference to FeatureExtractorNode + * \sa FeatureExtractorNode + */ +class FeatureExtractor : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_extract_from The packed function of `ExtractFrom`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 82f5b76834..ffeacefd75 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -36,8 +36,9 @@ class MutatorNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -110,6 +111,7 @@ class Mutator : public runtime::ObjectRef { * \brief Create a mutator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The mutator created. */ TVM_DLL static Mutator PyMutator( diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 058a2afe27..0a50b0ca15 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -41,8 +41,9 @@ class PostprocNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -115,6 +116,7 @@ class Postproc : public runtime::ObjectRef { * \brief Create a post processing with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The post processing created. */ TVM_DLL static Postproc PyPostproc( diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index b9e6d87774..9c0eaa8088 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -36,8 +36,9 @@ class ScheduleRuleNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -155,6 +156,7 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The schedule rule created. */ TVM_DLL static ScheduleRule PyScheduleRule( diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py new file mode 100644 index 0000000000..ddc2e760e3 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.feature_extractor package. +Meta Schedule feature extractors that +""" +from .feature_extractor import FeatureExtractor, PyFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py new file mode 100644 index 0000000000..5a6e75a122 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule FeatureExtractor.""" + +from typing import List + +from tvm._ffi import register_object +from tvm.runtime import Object, NDArray +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate + + +@register_object("meta_schedule.FeatureExtractor") +class FeatureExtractor(Object): + """Extractor for features from measure candidates for use in cost model.""" + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + """Extract features from the given measure candidate. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for feature extraction. + candidates : List[MeasureCandidate] + The measure candidates to extract features from. + + Returns + ------- + features : List[NDArray] + The feature ndarray extracted. + """ + return _ffi_api.FeatureExtractorExtractFrom(self, tune_context, candidates) + + +@register_object("meta_schedule.PyFeatureExtractor") +class PyFeatureExtractor(FeatureExtractor): + """An abstract feature extractor with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, FeatureExtractor) + def f_extract_from( + tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return self.extract_from(tune_context, candidates) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPyFeatureExtractor, # type: ignore # pylint: disable=no-member + f_extract_from, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 5316eb4663..e941729254 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,8 +16,7 @@ # under the License. """ The tvm.meta_schedule.postproc package. -Meta Schedule post processings that deal with the problem of -undertermined schedule validity after applying some schedule -primitves at runtime. +Meta Schedule post processings that extracts features from +measure candidates for use in cost model. """ from .postproc import Postproc, PyPostproc diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc new file mode 100644 index 0000000000..84d22493aa --- /dev/null +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +FeatureExtractor FeatureExtractor::PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_extract_from = std::move(f_extract_from); + n->f_as_string = std::move(f_as_string); + return FeatureExtractor(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); +TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") + .set_body_method(&FeatureExtractorNode::ExtractFrom); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") + .set_body_typed(FeatureExtractor::PyFeatureExtractor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index d8e96d0156..00bf3e48ca 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py new file mode 100644 index 0000000000..d27e3ce212 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List +import numpy as np +import re + +import tvm +from tvm.auto_scheduler import feature +from tvm.script import tir as T +from tvm.runtime import NDArray +from tvm.runtime.ndarray import numpyasarray +from tvm.tir.schedule import Schedule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor + + +def test_meta_schedule_feature_extractor(): + class FancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return [np.random.rand(4, 5)] + + extractor = FancyFeatureExtractor() + features = extractor.extract_from(TuneContext(), []) + assert len(features) == 1 + assert features[0].shape == (4, 5) + + +def test_meta_schedule_feature_extractor_as_string(): + class NotSoFancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return None + + feature_extractor = NotSoFancyFeatureExtractor() + pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)") + assert pattern.match(str(feature_extractor)) + + +if __name__ == "__main__": + test_meta_schedule_feature_extractor() + test_meta_schedule_feature_extractor_as_string() From d22b371ff40d29aa91f7a27a9b1b881dae1b5b05 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 9 Nov 2021 15:44:05 -0800 Subject: [PATCH 4/9] Fix init. --- python/tvm/meta_schedule/feature_extractor/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py index ddc2e760e3..49310decf3 100644 --- a/python/tvm/meta_schedule/feature_extractor/__init__.py +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -16,6 +16,7 @@ # under the License. """ The tvm.meta_schedule.feature_extractor package. -Meta Schedule feature extractors that +Meta Schedule feature extractors that extracts features from +measure candidates for use in cost model. """ from .feature_extractor import FeatureExtractor, PyFeatureExtractor From 89d7d61eed46509b1bcbde98d97a8830c4b462a0 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 9 Nov 2021 15:45:04 -0800 Subject: [PATCH 5/9] Revert wrong description. --- python/tvm/meta_schedule/postproc/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index e941729254..5316eb4663 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,7 +16,8 @@ # under the License. """ The tvm.meta_schedule.postproc package. -Meta Schedule post processings that extracts features from -measure candidates for use in cost model. +Meta Schedule post processings that deal with the problem of +undertermined schedule validity after applying some schedule +primitves at runtime. """ from .postproc import Postproc, PyPostproc From 5ed22ed4a5364c70ec298f7fbd402e6e07d68a76 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 10 Nov 2021 15:29:35 -0800 Subject: [PATCH 6/9] Add cost model. --- include/tvm/meta_schedule/cost_model.h | 187 ++++++++++++++++++ .../tvm/meta_schedule/cost_model/__init__.py | 22 +++ .../meta_schedule/cost_model/cost_model.py | 163 +++++++++++++++ src/meta_schedule/cost_model/cost_model.cc | 66 +++++++ src/meta_schedule/utils.h | 1 + .../unittest/test_meta_schedule_cost_model.py | 111 +++++++++++ .../test_meta_schedule_feature_extractor.py | 7 +- 7 files changed, 551 insertions(+), 6 deletions(-) create mode 100644 include/tvm/meta_schedule/cost_model.h create mode 100644 python/tvm/meta_schedule/cost_model/__init__.py create mode 100644 python/tvm/meta_schedule/cost_model/cost_model.py create mode 100644 src/meta_schedule/cost_model/cost_model.cc create mode 100644 tests/python/unittest/test_meta_schedule_cost_model.py diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h new file mode 100644 index 0000000000..cb0eade17e --- /dev/null +++ b/include/tvm/meta_schedule/cost_model.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_COST_MODEL_H_ +#define TVM_META_SCHEDULE_COST_MODEL_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Cost model for estimation of running time, thus reducing search space. */ +class CostModelNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~CostModelNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Load the cost model from given file location. + * \param file_location The file location. + * \return Whether cost model was loaded successfully. + */ + virtual bool Load(const String& file_location) = 0; + + /*! + * \brief Save the cost model to given file location. + * \param file_location The file location. + * \return Whether cost model was saved successfully. + */ + virtual bool Save(const String& file_location) = 0; + + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + * \return Whether cost model was updated successfully. + */ + virtual bool Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) = 0; + + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \return The predicted running results. + */ + virtual std::vector Predict(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; + +/*! \brief The cost model with customized methods on the python-side. */ +class PyCostModelNode : public CostModelNode { + public: + /*! + * \brief Load the cost model from given file location. + * \param file_location The file location. + * \return Whether cost model was loaded successfully. + */ + using FLoad = runtime::TypedPackedFunc; + /*! + * \brief Save the cost model to given file location. + * \param file_location The file location. + * \return Whether cost model was saved successfully. + */ + using FSave = runtime::TypedPackedFunc; + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + * \return Whether cost model was updated successfully. + */ + using FUpdate = runtime::TypedPackedFunc&, + const Array&)>; + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param p_addr The address to save the the estimated running results. + */ + using FPredict = runtime::TypedPackedFunc&, + void* p_addr)>; + /*! + * \brief Get the cost model as string with name. + * \return The string representation of the cost model. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Load` funcion. */ + FLoad f_load; + /*! \brief The packed function to the `Save` funcion. */ + FSave f_save; + /*! \brief The packed function to the `Update` funcion. */ + FUpdate f_update; + /*! \brief The packed function to the `Predict` funcion. */ + FPredict f_predict; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_load` is not visited + // `f_save` is not visited + // `f_update` is not visited + // `f_predict` is not visited + // `f_as_string` is not visited + } + + bool Load(const String& file_location) { + ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + return f_load(file_location); + } + + bool Save(const String& file_location) { + ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + return f_save(file_location); + } + + bool Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) { + ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + return f_update(tune_context, candidates, results); + } + + std::vector Predict(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + std::vector result(candidates.size(), 0.0); + f_predict(tune_context, candidates, result.data()); + return result; + } + + static constexpr const char* _type_key = "meta_schedule.PyCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); +}; + +/*! + * \brief Managed reference to CostModelNode + * \sa CostModelNode + */ +class CostModel : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_load The packed function of `Load`. + * \param f_save The packed function of `Save`. + * \param f_update The packed function of `Update`. + * \param f_predict The packed function of `Predict`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_COST_MODEL_H_ diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py new file mode 100644 index 0000000000..1784953826 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.cost_model package. +Meta Schedule cost model for estimation of running time, thus +reducing search space. +""" +from .cost_model import CostModel, PyCostModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py new file mode 100644 index 0000000000..f24a05747c --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule CostModel.""" + +from typing import List, TYPE_CHECKING +import ctypes + +import numpy as np + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..runner import RunnerResult +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..utils import _get_hex_address, check_override + + +@register_object("meta_schedule.CostModel") +class CostModel(Object): + """Cost model for estimation of running time, thus reducing search space.""" + + def load(self, file_location: str) -> bool: + """Load the cost model from given file location. + + Parameters + ---------- + file_location : str + The file location. + + Return + ------ + result : bool + Whether cost model was loaded successfully. + """ + _ffi_api.CostModelLoad(self, file_location) # type: ignore # pylint: disable=no-member + + def save(self, file_location: str) -> bool: + """Save the cost model to given file location. + + Parameters + ---------- + file_location : str + The file location. + + Return + ------ + result : bool + Whether cost model was saved successfully. + """ + _ffi_api.CostModelSave(self, file_location) # type: ignore # pylint: disable=no-member + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> bool: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + + Return + ------ + result : bool + Whether cost model was updated successfully. + """ + _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + + Return + ------ + result : bool + The predicted running results. + """ + n = len(candidates) + results = np.zeros(shape=(n,), dtype="float64") + _ffi_api.CostModelPredict( + self, + tune_context, + candidates, + results.ctypes.data_as(ctypes.c_void_p), + ) # type: ignore # pylint: disable=no-member + return results + + +@register_object("meta_schedule.PyCostModel") +class PyCostModel(CostModel): + """An abstract CostModel with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, CostModel) + def f_load(file_location: str) -> bool: + self.load(file_location) + + @check_override(self.__class__, CostModel) + def f_save(file_location: str) -> bool: + self.save(file_location) + + @check_override(self.__class__, CostModel) + def f_update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> bool: + self.update(tune_context, candidates, results) + + @check_override(self.__class__, CostModel) + def f_predict(tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr): + n = len(candidates) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = self.predict(tune_context, candidates) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member + f_load, + f_save, + f_update, + f_predict, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc new file mode 100644 index 0000000000..36f6a89d9a --- /dev/null +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_load = std::move(f_load); + n->f_save = std::move(f_save); + n->f_update = std::move(f_update); + n->f_predict = std::move(f_predict); + n->f_as_string = std::move(f_as_string); + return CostModel(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyCostModelNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(CostModelNode); +TVM_REGISTER_NODE_TYPE(PyCostModelNode); + +struct Internal { + static void CostModelPredict(CostModel model, const TuneContext& tune_context, + Array candidates, void* p_addr) { + std::vector result = model->Predict(tune_context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + } +}; + +TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") + .set_body_method(&CostModelNode::Update); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict").set_body_typed(Internal::CostModelPredict); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 00bf3e48ca..9a4b33e6c2 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py new file mode 100644 index 0000000000..b662bdaaef --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List +import numpy as np +import re + +import tvm +from tvm.script import tir as T +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.cost_model import PyCostModel +from tvm.tir.schedule.schedule import Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def test_meta_schedule_cost_model(): + class FancyCostModel(PyCostModel): + def load(self, file_location: str) -> bool: + return True + + def save(self, file_location: str) -> bool: + return True + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> bool: + return True + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10, 12) + + model = FancyCostModel() + assert model.save("fancy_test_location") + assert model.load("fancy_test_location") + assert model.update(TuneContext(), [], []) + results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) + assert len(results) == 1 + assert results[0].shape == (10, 12) + + +def test_meta_schedule_cost_model_as_string(): + class NotSoFancyCostModel(PyCostModel): + def load(self, file_location: str) -> bool: + return True + + def save(self, file_location: str) -> bool: + return True + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> bool: + return True + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10, 12) + + cost_model = NotSoFancyCostModel() + pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)") + assert pattern.match(str(cost_model)) + + +if __name__ == "__main__": + # test_meta_schedule_cost_model() + test_meta_schedule_cost_model_as_string() diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py index d27e3ce212..05b2bae40b 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -16,17 +16,12 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List + import numpy as np import re -import tvm -from tvm.auto_scheduler import feature -from tvm.script import tir as T from tvm.runtime import NDArray -from tvm.runtime.ndarray import numpyasarray -from tvm.tir.schedule import Schedule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.utils import _get_hex_address from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.feature_extractor import PyFeatureExtractor From 8a365bce36a7ede8a9bd609e8932b62c2bade66d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 10 Nov 2021 15:32:25 -0800 Subject: [PATCH 7/9] Remove unused include. --- include/tvm/meta_schedule/cost_model.h | 1 - include/tvm/meta_schedule/feature_extractor.h | 3 --- 2 files changed, 4 deletions(-) diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index cb0eade17e..0bc2293305 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -21,7 +21,6 @@ #define TVM_META_SCHEDULE_COST_MODEL_H_ #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 2695ae495d..2c145184ab 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -20,9 +20,6 @@ #ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ #define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ -#include -#include - namespace tvm { namespace meta_schedule { From 406847e617a39bb9a27e75c49b779ba949a129da Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 12 Nov 2021 12:58:27 -0800 Subject: [PATCH 8/9] Fix issues. --- .../extend_tvm/bring_your_own_datatypes.py | 2 +- include/tvm/auto_scheduler/cost_model.h | 6 ++-- include/tvm/auto_scheduler/measure.h | 2 +- include/tvm/meta_schedule/cost_model.h | 21 ++++++------- include/tvm/meta_schedule/feature_extractor.h | 6 ++-- include/tvm/meta_schedule/measure_callback.h | 4 +-- include/tvm/meta_schedule/mutator.h | 6 ++-- include/tvm/meta_schedule/space_generator.h | 2 +- include/tvm/meta_schedule/task_scheduler.h | 4 +-- .../tvm/meta_schedule/cost_model/__init__.py | 2 -- .../meta_schedule/cost_model/cost_model.py | 31 +++++++++---------- .../feature_extractor/feature_extractor.py | 5 +-- src/meta_schedule/cost_model/cost_model.cc | 17 +++++----- src/runtime/library_module.h | 2 +- .../unittest/test_meta_schedule_cost_model.py | 14 ++++----- 15 files changed, 61 insertions(+), 63 deletions(-) diff --git a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py index 1cf556ddd0..0182456099 100644 --- a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py +++ b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py @@ -313,7 +313,7 @@ def convert_ndarray(dst_dtype, array): print(str(e).split("\n")[-1]) ###################################################################### -# When we attempt to run the model, we get a familiar error telling us that more funcions need to be registerd for myfloat. +# When we attempt to run the model, we get a familiar error telling us that more functions need to be registerd for myfloat. # # Because this is a neural network, many more operations are required. # Here, we register all the needed functions: diff --git a/include/tvm/auto_scheduler/cost_model.h b/include/tvm/auto_scheduler/cost_model.h index f7a27895a7..a52c6797b6 100755 --- a/include/tvm/auto_scheduler/cost_model.h +++ b/include/tvm/auto_scheduler/cost_model.h @@ -122,11 +122,11 @@ class RandomModel : public CostModel { * This class will call functions defined in the python */ class PythonBasedModelNode : public CostModelNode { public: - /*! \brief Pointer to the update funcion in python */ + /*! \brief Pointer to the update function in python */ PackedFunc update_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_stage_func; void Update(const Array& inputs, const Array& results) final; diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index 841b6b9530..20a93e280b 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -236,7 +236,7 @@ class MeasureCallback : public ObjectRef { * This class will call functions defined in the python */ class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { public: - /*! \brief Pointer to the callback funcion in python */ + /*! \brief Pointer to the callback function in python */ PackedFunc callback_func; void Callback(const SearchPolicy& policy, const Array& inputs, diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 0bc2293305..dfc816bb00 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -27,7 +27,7 @@ namespace meta_schedule { class TuneContext; -/*! \brief Cost model for estimation of running time, thus reducing search space. */ +/*! \brief Cost model. */ class CostModelNode : public runtime::Object { public: /*! \brief Virtual destructor. */ @@ -54,9 +54,8 @@ class CostModelNode : public runtime::Object { * \param tune_context The tuning context. * \param candidates The measure candidates. * \param results The running results of the measure candidates. - * \return Whether cost model was updated successfully. */ - virtual bool Update(const TuneContext& tune_context, const Array& candidates, + virtual void Update(const TuneContext& tune_context, const Array& candidates, const Array& results) = 0; /*! @@ -94,7 +93,7 @@ class PyCostModelNode : public CostModelNode { * \param results The running results of the measure candidates. * \return Whether cost model was updated successfully. */ - using FUpdate = runtime::TypedPackedFunc&, + using FUpdate = runtime::TypedPackedFunc&, const Array&)>; /*! * \brief Predict the running results of given measure candidates. @@ -110,15 +109,15 @@ class PyCostModelNode : public CostModelNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Load` funcion. */ + /*! \brief The packed function to the `Load` function. */ FLoad f_load; - /*! \brief The packed function to the `Save` funcion. */ + /*! \brief The packed function to the `Save` function. */ FSave f_save; - /*! \brief The packed function to the `Update` funcion. */ + /*! \brief The packed function to the `Update` function. */ FUpdate f_update; - /*! \brief The packed function to the `Predict` funcion. */ + /*! \brief The packed function to the `Predict` function. */ FPredict f_predict; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { @@ -139,10 +138,10 @@ class PyCostModelNode : public CostModelNode { return f_save(file_location); } - bool Update(const TuneContext& tune_context, const Array& candidates, + void Update(const TuneContext& tune_context, const Array& candidates, const Array& results) { ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; - return f_update(tune_context, candidates, results); + f_update(tune_context, candidates, results); } std::vector Predict(const TuneContext& tune_context, diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 2c145184ab..30e2f0fe62 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -20,6 +20,8 @@ #ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ #define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#include + namespace tvm { namespace meta_schedule { @@ -63,9 +65,9 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `ExtractFrom` funcion. */ + /*! \brief The packed function to the `ExtractFrom` function. */ FExtractFrom f_extract_from; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 9ee7039959..67152397dc 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -81,9 +81,9 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index ffeacefd75..e1c5ee3be9 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -73,11 +73,11 @@ class PyMutatorNode : public MutatorNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index a0dfede820..c0f37c6037 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -102,7 +102,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { */ using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; /*! \brief The packed function to the `GenerateDesignSpace` function. */ FGenerateDesignSpace f_generate_design_space; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index b8cfe7b381..f0c3a3c208 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -162,9 +162,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ using FNextTaskId = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Tune` funcion. */ + /*! \brief The packed function to the `Tune` function. */ FTune f_tune; - /*! \brief The packed function to the `InitializeTask` funcion. */ + /*! \brief The packed function to the `InitializeTask` function. */ FInitializeTask f_initialize_task; /*! \brief The packed function to the `SetTaskStopped` function. */ FSetTaskStopped f_set_task_stopped; diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 1784953826..7267c5ae54 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -16,7 +16,5 @@ # under the License. """ The tvm.meta_schedule.cost_model package. -Meta Schedule cost model for estimation of running time, thus -reducing search space. """ from .cost_model import CostModel, PyCostModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f24a05747c..da29c7db66 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -16,7 +16,7 @@ # under the License. """Meta Schedule CostModel.""" -from typing import List, TYPE_CHECKING +from typing import List import ctypes import numpy as np @@ -33,7 +33,7 @@ @register_object("meta_schedule.CostModel") class CostModel(Object): - """Cost model for estimation of running time, thus reducing search space.""" + """Cost model.""" def load(self, file_location: str) -> bool: """Load the cost model from given file location. @@ -48,7 +48,7 @@ def load(self, file_location: str) -> bool: result : bool Whether cost model was loaded successfully. """ - _ffi_api.CostModelLoad(self, file_location) # type: ignore # pylint: disable=no-member + return bool(_ffi_api.CostModelLoad(self, file_location)) # type: ignore # pylint: disable=no-member def save(self, file_location: str) -> bool: """Save the cost model to given file location. @@ -63,14 +63,14 @@ def save(self, file_location: str) -> bool: result : bool Whether cost model was saved successfully. """ - _ffi_api.CostModelSave(self, file_location) # type: ignore # pylint: disable=no-member + return bool(_ffi_api.CostModelSave(self, file_location)) # type: ignore # pylint: disable=no-member def update( self, tune_context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], - ) -> bool: + ) -> None: """Update the cost model given running results. Parameters @@ -81,11 +81,6 @@ def update( The measure candidates. results : List[RunnerResult] The running results of the measure candidates. - - Return - ------ - result : bool - Whether cost model was updated successfully. """ _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member @@ -99,7 +94,6 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) candidates : List[MeasureCandidate] The measure candidates. - Return ------ result : bool @@ -107,12 +101,12 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) """ n = len(candidates) results = np.zeros(shape=(n,), dtype="float64") - _ffi_api.CostModelPredict( + _ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member self, tune_context, candidates, results.ctypes.data_as(ctypes.c_void_p), - ) # type: ignore # pylint: disable=no-member + ) return results @@ -125,11 +119,11 @@ def __init__(self): @check_override(self.__class__, CostModel) def f_load(file_location: str) -> bool: - self.load(file_location) + return self.load(file_location) @check_override(self.__class__, CostModel) def f_save(file_location: str) -> bool: - self.save(file_location) + return self.save(file_location) @check_override(self.__class__, CostModel) def f_update( @@ -141,11 +135,16 @@ def f_update( self.update(tune_context, candidates, results) @check_override(self.__class__, CostModel) - def f_predict(tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr): + def f_predict( + tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr + ) -> None: n = len(candidates) return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = self.predict(tune_context, candidates) + assert ( + array_wrapper.dtype == "float64" + ), "ValueError: Invalid data type returned from CostModel Predict!" def f_as_string() -> str: return str(self) diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index 5a6e75a122..4126e0fb45 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -20,7 +20,6 @@ from tvm._ffi import register_object from tvm.runtime import Object, NDArray -from tvm.tir.schedule import Schedule from .. import _ffi_api from ..utils import _get_hex_address, check_override @@ -49,7 +48,9 @@ def extract_from( features : List[NDArray] The feature ndarray extracted. """ - return _ffi_api.FeatureExtractorExtractFrom(self, tune_context, candidates) + return _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member + self, tune_context, candidates + ) @register_object("meta_schedule.PyFeatureExtractor") diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 36f6a89d9a..5cd32b097c 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -47,19 +47,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_NODE_TYPE(PyCostModelNode); -struct Internal { - static void CostModelPredict(CostModel model, const TuneContext& tune_context, - Array candidates, void* p_addr) { - std::vector result = model->Predict(tune_context, candidates); - std::copy(result.begin(), result.end(), static_cast(p_addr)); - } -}; - TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") .set_body_method(&CostModelNode::Update); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict").set_body_typed(Internal::CostModelPredict); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") + .set_body_typed([](CostModel model, // + const TuneContext& tune_context, // + Array candidates, // + void* p_addr) -> void { + std::vector result = model->Predict(tune_context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + }); TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); } // namespace meta_schedule diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index b5780975f4..44dc323186 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -79,7 +79,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& void InitContextFunctions(std::function fgetsymbol); /*! - * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \brief Type alias for function to wrap a TVMBackendPackedCFunc. * \param The function address imported from a module. * \param mptr The module pointer node. * \return Packed function that wraps the invocation of the function at faddr. diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index b662bdaaef..14253ff0cb 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -63,18 +63,18 @@ def update( tune_context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], - ) -> bool: - return True + ) -> None: + pass def predict( self, tune_context: TuneContext, candidates: List[MeasureCandidate] ) -> np.ndarray: - return np.random.rand(10, 12) + return [np.random.rand(10, 12)] model = FancyCostModel() assert model.save("fancy_test_location") assert model.load("fancy_test_location") - assert model.update(TuneContext(), [], []) + model.update(TuneContext(), [], []) results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) assert len(results) == 1 assert results[0].shape == (10, 12) @@ -93,8 +93,8 @@ def update( tune_context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], - ) -> bool: - return True + ) -> None: + pass def predict( self, tune_context: TuneContext, candidates: List[MeasureCandidate] @@ -107,5 +107,5 @@ def predict( if __name__ == "__main__": - # test_meta_schedule_cost_model() + test_meta_schedule_cost_model() test_meta_schedule_cost_model_as_string() From 73dfaee6fa828f32ff2f4f8e6ee5fe5ae6abd9ed Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 12 Nov 2021 13:38:53 -0800 Subject: [PATCH 9/9] Fix init. --- python/tvm/meta_schedule/__init__.py | 2 ++ .../unittest/test_meta_schedule_cost_model.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index c57355a391..37e8ffa9d8 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -25,4 +25,6 @@ from . import space_generator from . import search_strategy from . import integration +from . import feature_extractor +from . import cost_model from .tune_context import TuneContext diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index 14253ff0cb..d1ce14e5a4 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -14,23 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List -import numpy as np + import re +import numpy as np import tvm from tvm.script import tir as T from tvm.meta_schedule import TuneContext from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.cost_model import PyCostModel from tvm.tir.schedule.schedule import Schedule +from tvm.meta_schedule.cost_model import PyCostModel # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring - - @tvm.script.ir_module class Matmul: @T.prim_func @@ -47,7 +45,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument def test_meta_schedule_cost_model(): @@ -69,15 +67,14 @@ def update( def predict( self, tune_context: TuneContext, candidates: List[MeasureCandidate] ) -> np.ndarray: - return [np.random.rand(10, 12)] + return np.random.rand(10) model = FancyCostModel() assert model.save("fancy_test_location") assert model.load("fancy_test_location") model.update(TuneContext(), [], []) results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) - assert len(results) == 1 - assert results[0].shape == (10, 12) + assert results.shape == (10,) def test_meta_schedule_cost_model_as_string(): @@ -99,7 +96,7 @@ def update( def predict( self, tune_context: TuneContext, candidates: List[MeasureCandidate] ) -> np.ndarray: - return np.random.rand(10, 12) + return np.random.rand(10) cost_model = NotSoFancyCostModel() pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)")