Skip to content

Commit

Permalink
[M3c][MetaScheduler] Add More Measure Callbacks. (#9780)
Browse files Browse the repository at this point in the history
* Add measure callbacks.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Fix comments.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
  • Loading branch information
7 people authored Dec 23, 2021
1 parent d026d06 commit b35fc83
Show file tree
Hide file tree
Showing 16 changed files with 1,108 additions and 0 deletions.
146 changes: 146 additions & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

class TaskScheduler;

/*! \brief Rules to apply after measure results is available. */
class MeasureCallbackNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~MeasureCallbackNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Apply a measure callback rule with given arguments.
* \param task_scheduler The task scheduler.
* \param task_id The id of the task (tune context) to apply measure callbacks.
* \param measure_candidates The measure candidates.
* \param builder_results The builder results by building the measure candidates.
* \param runner_results The runner results by running the built measure candidates.
*/
virtual void Apply(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builder_results, //
const Array<RunnerResult>& runner_results) = 0;

static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
};

/*! \brief The measure callback with customized methods on the python-side. */
class PyMeasureCallbackNode : public MeasureCallbackNode {
public:
/*!
* \brief Apply a measure callback to the given schedule.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
using FApply =
runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results)>;
/*!
* \brief Get the measure callback function as string with name.
* \return The string of the measure callback function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_apply` is not visited
// `f_as_string` is not visited
}

void Apply(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results);
}

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
};

/*!
* \brief Managed reference to MeasureCallbackNode
* \sa MeasureCallbackNode
*/
class MeasureCallback : public runtime::ObjectRef {
public:
/*!
* \brief Create a measure callback that adds the measurement results into the database
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback AddToDatabase();
/*!
* \brief Create a measure callback that removes the build artifacts from the disk
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback RemoveBuildArtifact();
/*!
* \brief Create a measure callback that echos the statistics of the tuning process to the console
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback EchoStatistics();
/*!
* \brief Create a measure callback that updates the cost model with measurement result.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback UpdateCostModel();
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \param f_as_string The packed function of `AsString`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
6 changes: 6 additions & 0 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The cost model of the scheduler. */
Optional<CostModel> cost_model;
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;

/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -82,6 +86,8 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
}

/*! \brief Auto-tuning. */
Expand Down
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/meta_schedule/measure_callback/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.measure_callback package.
"""
from .measure_callback import MeasureCallback, PyMeasureCallback
from .add_to_database import AddToDatabase
from .echo_statistics import EchoStatistics
from .remove_build_artifact import RemoveBuildArtifact
from .update_cost_model import UpdateCostModel
30 changes: 30 additions & 0 deletions python/tvm/meta_schedule/measure_callback/add_to_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A callback that adds the measurement results into the database"""
from tvm._ffi import register_object

from .. import _ffi_api
from .measure_callback import MeasureCallback


@register_object("meta_schedule.AddToDatabase")
class AddToDatabase(MeasureCallback):
def __init__(self) -> None:
"""A callback that adds the measurement results into the database"""
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member
)
30 changes: 30 additions & 0 deletions python/tvm/meta_schedule/measure_callback/echo_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A callback that echos the statistics of the tuning process to the console"""
from tvm._ffi import register_object

from .. import _ffi_api
from .measure_callback import MeasureCallback


@register_object("meta_schedule.EchoStatistics")
class EchoStatistics(MeasureCallback):
def __init__(self) -> None:
"""A callback that echos the statistics of the tuning process to the console"""
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member
)
104 changes: 104 additions & 0 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule MeasureCallback."""

from typing import List, TYPE_CHECKING

from tvm._ffi import register_object
from tvm.runtime import Object

from .. import _ffi_api
from ..builder import BuilderResult
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..utils import _get_hex_address, check_override

if TYPE_CHECKING:
from ..task_scheduler import TaskScheduler


@register_object("meta_schedule.MeasureCallback")
class MeasureCallback(Object):
"""Rules to apply after measure results is available."""

def apply(
self,
task_scheduler: "TaskScheduler",
task_id: int,
measure_candidates: List[MeasureCandidate],
builder_results: List[BuilderResult],
runner_results: List[RunnerResult],
) -> None:
"""Apply a measure callback to the given schedule.
Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
task_id: int
The task id.
measure_candidates: List[MeasureCandidate]
The measure candidates.
builder_results: List[BuilderResult]
The builder results by building the measure candidates.
runner_results: List[RunnerResult]
The runner results by running the built measure candidates.
"""
return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member
self,
task_scheduler,
task_id,
measure_candidates,
builder_results,
runner_results,
)


@register_object("meta_schedule.PyMeasureCallback")
class PyMeasureCallback(MeasureCallback):
"""An abstract MeasureCallback with customized methods on the python-side."""

def __init__(self):
"""Constructor."""

@check_override(self.__class__, MeasureCallback)
def f_apply(
task_scheduler: "TaskScheduler",
task_id: int,
measure_candidates: List[MeasureCandidate],
builder_results: List[BuilderResult],
runner_results: List[RunnerResult],
) -> None:
return self.apply(
task_scheduler,
task_id,
measure_candidates,
builder_results,
runner_results,
)

def f_as_string() -> str:
return str(self)

self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member
f_apply,
f_as_string,
)

def __str__(self) -> str:
return f"PyMeasureCallback({_get_hex_address(self.handle)})"
30 changes: 30 additions & 0 deletions python/tvm/meta_schedule/measure_callback/remove_build_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A callback that removes the build artifacts from the disk"""
from tvm._ffi import register_object

from .. import _ffi_api
from .measure_callback import MeasureCallback


@register_object("meta_schedule.RemoveBuildArtifact")
class RemoveBuildArtifact(MeasureCallback):
def __init__(self) -> None:
"""A callback that removes the build artifacts from the disk"""
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member
)
Loading

0 comments on commit b35fc83

Please sign in to comment.