Skip to content

Commit

Permalink
[AutoScheduler] Python based measure callbacks (apache#7143)
Browse files Browse the repository at this point in the history
* add

* make it work

* format

* add poilcy

* comment

* move test

* format

* fix ci

* Delete useless old code

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
  • Loading branch information
2 people authored and trevor-m committed Jan 21, 2021
1 parent ef9c213 commit a3c2006
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 1 deletion.
29 changes: 29 additions & 0 deletions include/tvm/auto_scheduler/measure.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,35 @@ class MeasureCallback : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

/*! \brief A wrapper for measure callback defined by python code
* This class will call functions defined in the python */
class PythonBasedMeasureCallbackNode : public MeasureCallbackNode {
public:
/*! \brief Pointer to the callback funcion in python */
PackedFunc callback_func;

void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
const Array<MeasureResult>& results) final;
static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode);
};

/*!
* \brief Managed reference to PythonBasedMeasureCallbackNode.
* \sa PythonBasedMeasureCallbackNode
*/
class PythonBasedMeasureCallback : public MeasureCallback {
public:
/*!
* \brief The constructor.
* \param callback_func The pointer to the callback function defined in python
*/
explicit PythonBasedMeasureCallback(PackedFunc callback_func);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback,
PythonBasedMeasureCallbackNode);
};

// The base class of ProgramBuilders and ProgramRunners.

/*! \brief ProgramBuilder that builds the programs */
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,31 @@ class MeasureCallback(Object):
""" The base class of measurement callback functions. """


@tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback")
class PythonBasedMeasureCallback(MeasureCallback):
"""Base class for measure callbacks implemented in python"""

def __init__(self):
def callback_func(policy, inputs, results):
self.callback(policy, inputs, results)

self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func)

def callback(self, policy, inputs, results):
"""The callback function.
Parameters
----------
policy: auto_scheduler.search_policy.SearchPolicy
The search policy.
inputs : List[auto_scheduler.measure.MeasureInput]
The measurement inputs
results : List[auto_scheduler.measure.MeasureResult]
The measurement results
"""
raise NotImplementedError


@tvm._ffi.register_object("auto_scheduler.MeasureInput")
class MeasureInput(Object):
"""Store the input of a measurement.
Expand Down
27 changes: 27 additions & 0 deletions src/auto_scheduler/measure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <algorithm>

#include "search_policy/empty_policy.h"
#include "search_policy/sketch_policy.h"
#include "utils.h"

namespace tvm {
Expand All @@ -36,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode);
TVM_REGISTER_NODE_TYPE(BuildResultNode);
TVM_REGISTER_NODE_TYPE(MeasureResultNode);
TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
Expand Down Expand Up @@ -183,6 +186,25 @@ Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
return Array<MeasureResult>();
}

/********** MeasureCallback **********/
PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) {
auto node = make_object<PythonBasedMeasureCallbackNode>();
node->callback_func = std::move(callback_func);
data_ = std::move(node);
}

void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy,
const Array<MeasureInput>& inputs,
const Array<MeasureResult>& results) {
if (auto* sketch_policy = static_cast<SketchPolicyNode*>(policy.operator->())) {
callback_func(GetRef<SketchPolicy>(sketch_policy), inputs, results);
} else if (auto* empty_policy = static_cast<EmptyPolicyNode*>(policy.operator->())) {
callback_func(GetRef<EmptyPolicy>(empty_policy), inputs, results);
} else {
LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy";
}
}

/********** ProgramMeasurer **********/
ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> callbacks, int verbose,
Expand Down Expand Up @@ -360,6 +382,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
});

TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback")
.set_body_typed([](PackedFunc callback_func) {
return PythonBasedMeasureCallback(callback_func);
});

TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
.set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
Expand Down
12 changes: 11 additions & 1 deletion tests/python/unittest/test_auto_scheduler_search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@
import multiprocessing


class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback):
"""A simple Python-based callback for testing."""

def callback(self, policy, inputs, results):
assert isinstance(policy, auto_scheduler.search_policy.SearchPolicy)
for inp, res in zip(inputs, results):
assert isinstance(inp, auto_scheduler.MeasureInput)
assert isinstance(res, auto_scheduler.MeasureResult)


def search_common(
workload=matmul_auto_scheduler_test,
target="llvm",
Expand Down Expand Up @@ -68,7 +78,7 @@ def search_common(
early_stopping=1,
runner=runner,
verbose=2,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
)
task.tune(tuning_options=tuning_options, search_policy=search_policy)
sch, args = task.apply_best(log_file)
Expand Down

0 comments on commit a3c2006

Please sign in to comment.