Skip to content

Commit

Permalink
[M3c][Meta Schedule] Measure Callbacks (#498)
Browse files Browse the repository at this point in the history
* Add MeasureCallbacks.

* ove measure callback to task scheduler level.
  • Loading branch information
zxybazh authored and junrushao committed Nov 5, 2021
1 parent 5c0f736 commit b64d173
Show file tree
Hide file tree
Showing 16 changed files with 515 additions and 14 deletions.
125 changes: 125 additions & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

class TaskScheduler;

/*! \brief Rules to apply after measure results is available. */
class MeasureCallbackNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~MeasureCallbackNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Apply a measure callback rule with given arguments.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
virtual bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) = 0;

static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
};

/*! \brief The measure callback with customized methods on the python-side. */
class PyMeasureCallbackNode : public MeasureCallbackNode {
public:
/*!
* \brief Apply a measure callback to the given schedule.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
using FApply =
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results)>;
/*!
* \brief Get the measure callback function as string with name.
* \return The string of the measure callback function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `Apply` funcion. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_apply` is not visited
// `f_as_string` is not visited
}

bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
};

/*!
* \brief Managed reference to MeasureCallbackNode
* \sa MeasureCallbackNode
*/
class MeasureCallback : public runtime::ObjectRef {
public:
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
12 changes: 11 additions & 1 deletion include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>

Expand Down Expand Up @@ -73,6 +74,8 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;

/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -82,6 +85,7 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
v->Visit("measure_callbacks", &measure_callbacks);
}

/*! \brief Auto-tuning. */
Expand Down Expand Up @@ -222,8 +226,14 @@ class TaskScheduler : public runtime::ObjectRef {
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database);
Database database, //
Array<MeasureCallback> measure_callbacks);
TVM_DLL static TaskScheduler PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
Array<MeasureCallback> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/meta_schedule/measure_callback/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.measure_callback package.
"""
from .measure_callback import MeasureCallback, PyMeasureCallback
99 changes: 99 additions & 0 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule MeasureCallback."""

from typing import TYPE_CHECKING, List

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.utils import _get_hex_address

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext
from ..task_scheduler import TaskScheduler


@register_object("meta_schedule.MeasureCallback")
class MeasureCallback(Object):
"""Rules to apply after measure results is available."""

def apply(
self,
task_scheduler: "TaskScheduler",
tasks: List["TuneContext"],
measure_candidates: List[MeasureCandidate],
builds: List[BuilderResult],
results: List[RunnerResult],
) -> bool:
"""Apply a measure callback to the given schedule.
Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
tasks: List[TuneContext]
The list of tune context to process.
measure_candidats: List[MeasureCandidate]
The measure candidates.
builds: List[BuilderResult]
The builder results by building the measure candidates.
results: List[RunnerResult]
The runner results by running the built measure candidates.
Returns
-------
result : bool
Whether the measure callback was successfully applied.
"""
return _ffi_api.MeasureCallbackApply(
self, task_scheduler, tasks, measure_candidates, builds, results
)


@register_object("meta_schedule.PyMeasureCallback")
class PyMeasureCallback(MeasureCallback):
"""An abstract MeasureCallback with customized methods on the python-side."""

def __init__(self):
"""Constructor."""

def f_apply(
task_scheduler: "TaskScheduler",
tasks: List[TuneContext],
measure_candidates: List[MeasureCandidate],
builds: List[BuilderResult],
results: List[RunnerResult],
) -> bool:
return self.apply(task_scheduler, tasks, measure_candidates, builds, results)

def f_as_string() -> str:
return str(self)

self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member
f_apply,
f_as_string,
)

def __str__(self) -> str:
return f"PyMeasureCallback({_get_hex_address(self.handle)})"
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
Parameters
----------
tune_context : TuneContext
The tuning context for initializing the design space generator.
The tuning context for initializing the mutator.
"""
_ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
self, tune_context
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
Parameters
----------
tune_context : TuneContext
The tuning context for initializing the design space generator.
The tuning context for initializing the post processing.
"""
_ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
self, tune_context
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
Parameters
----------
tune_context : TuneContext
The tuning context for initializing the design space generator.
The tuning context for initializing the schedule rule.
"""
_ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
self, tune_context
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
to generate measure candidates.
"""

from .search_strategy import SearchStrategy, PySearchStrategy
from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
from .replay_trace import ReplayTrace
from .replay_func import ReplayFunc
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import List, TYPE_CHECKING

from tvm._ffi import register_object
from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback

from ..builder import Builder
from ..runner import Runner
Expand All @@ -41,6 +42,7 @@ def __init__(
builder: Builder,
runner: Runner,
database: Database,
measure_callbacks: List[MeasureCallback] = [],
) -> None:
"""Constructor.
Expand All @@ -54,11 +56,14 @@ def __init__(
The runner.
database : Database
The database.
measure_callbacks: List[MeasureCallback]
The list of measure callbacks of the scheduler.
"""
self.__init_handle_by_constructor__(
_ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member
tasks,
builder,
runner,
database,
measure_callbacks,
)
Loading

0 comments on commit b64d173

Please sign in to comment.