Skip to content

Commit

Permalink
Squashed commit
Browse files Browse the repository at this point in the history
[Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (#485)

[Meta Schedule][M3c] PostOrderApply (#486)

Fix Post Order Apply (#490)

[MetaSchedule] Relay Integration (#489)

[M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (#492)

Fix replay trace. (#493)

[M3c][Meta Schedule] Implement the Replay Func class. (#495)

[PR] Test script for meta-schedule task extraction. Interface to load… (#494)

[Meta Schedule Refactor] Get child blocks (#500)

Read-at && Write-at (#497)

[M3c][Meta Schedule] Measure Callbacks (#498)

[Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (#496)

[MetaSchedule] Sample-Perfect-Tile (#501)

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
  • Loading branch information
8 people committed Nov 5, 2021
1 parent b8fb438 commit a2b4a94
Show file tree
Hide file tree
Showing 84 changed files with 5,343 additions and 246 deletions.
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode {
}

Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
return f_build(build_inputs);
}

Expand Down
23 changes: 17 additions & 6 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode {
// `f_size` is not visited
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);

Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); }
Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
}

void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); }
void CommitTuningRecord(const TuningRecord& record) final {
ICHECK(f_commit_tuning_record != nullptr)
<< "PyDatabase's CommitTuningRecord method not implemented!";
f_commit_tuning_record(record);
}

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
return f_get_top_k(workload, top_k);
}

int64_t Size() final { return f_size(); }
int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
};

/*!
Expand Down
214 changes: 214 additions & 0 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_INTEGRATION_H_
#define TVM_META_SCHEDULE_INTEGRATION_H_

#include <tvm/meta_schedule/database.h>
#include <tvm/support/with.h>

#include <unordered_set>

namespace tvm {
namespace meta_schedule {

/**************** ExtractedTask ****************/

/*!
* \brief A tuning task extracted from the high-level IR
*/
class ExtractedTaskNode : public runtime::Object {
public:
/*! \brief The name of the task extracted */
String task_name;
/*! \brief The high-level IR */
IRModule mod;
/*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
Array<IRModule> dispatched;

void VisitAttrs(AttrVisitor* v) {
v->Visit("task_name", &task_name);
v->Visit("mod", &mod);
v->Visit("dispatched", &dispatched);
}

static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object);
};

/*!
* \brief Managed reference to ExtractedTaskNode
* \sa ExtractedTaskNode
*/
class ExtractedTask : public runtime::ObjectRef {
public:
/*!
* \brief Constructor. The name of the task extracted
* \brief The high-level IR
* \brief A list of low-level IRs that the high-level IR could potentially dispatch to
*/
explicit ExtractedTask(String task_name, IRModule mod, Array<IRModule> dispatched);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode);
};

/**************** IntegrationContext ****************/

/*!
* \brief A context manager interface for the integration
*/
class IntegrationContextNode : public runtime::Object {
public:
/*! \brief Default destructor */
virtual ~IntegrationContextNode() = default;
/*!
* \brief The entry point of the integration
* \param task_name The name of the task
* \param mod The high-level IR
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
* NullOpt means the dispatch needs to be done in the context.
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
Optional<Array<IRModule>> dispatched) = 0;

static constexpr const char* _type_key = "meta_schedule.IntegrationContext";
TVM_DECLARE_BASE_OBJECT_INFO(IntegrationContextNode, runtime::Object);
};

/*!
* \brief Managed reference to IntegrationContextNode
* \sa IntegrationContextNode
*/
class IntegrationContext : public runtime::ObjectRef {
friend class IntegrationContextInternal;
friend class With<IntegrationContext>;

public:
/*! \brief Default destructor */
virtual ~IntegrationContext() = default;
/*!
* \brief The context manager in the current scope
* \return The IntegrationContext in the current scope. NullOpt if it's currently not under any
* IntegrationContext.
*/
static Optional<IntegrationContext> Current();
/*!
* \brief The entry point of the integration workflow. The compilation process of the high-level
* IR should call this method for task extraction and for feedback hints
* \param task_name The name of the task
* \param mod The high-level IR
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
static Optional<ObjectRef> EntryPoint(runtime::String task_name, IRModule mod,
Optional<Array<IRModule>> dispatched);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IntegrationContext, runtime::ObjectRef,
IntegrationContextNode);

protected:
/*! \brief Default constructor */
IntegrationContext() = default;
/*! \brief Entering the scope of the context manager */
void EnterWithScope();
/*! \brief Exiting the scope of the context manager */
void ExitWithScope();
};

/**************** TaskExtraction ****************/

/*!
* \brief An integration context for task extraction
*/
class TaskExtractionNode : public IntegrationContextNode {
public:
/*! \brief The extracted tasks */
Array<ExtractedTask> tasks{nullptr};

void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, IntegrationContextNode);
};

/*!
* \brief Managed reference to TaskExtractionNode
* \sa TaskExtractionNode
*/
class TaskExtraction : public IntegrationContext {
public:
/*! \brief The path to a cache file storing extracted tasks */
TaskExtraction();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, IntegrationContext,
TaskExtractionNode);
};

/**************** ApplyHistoryBest ****************/

/*!
* \brief An integration context that allows application of historically best records from a
* database
*/
class ApplyHistoryBestNode : public IntegrationContextNode {
public:
/*! \brief The database to be queried from */
Database database{nullptr};

void VisitAttrs(AttrVisitor* v) {
v->Visit("database", &database); //
}

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, IntegrationContextNode);
};

/*!
* \brief Managed reference to ApplyHistoryBestNode
* \sa ApplyHistoryBestNode
*/
class ApplyHistoryBest : public IntegrationContext {
public:
/*!
* \brief Constructor
* \param database The database to be queried from
*/
explicit ApplyHistoryBest(Database database);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, IntegrationContext,
ApplyHistoryBestNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_INTEGRATION_H_
126 changes: 126 additions & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

class TaskScheduler;

/*! \brief Rules to apply after measure results is available. */
class MeasureCallbackNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~MeasureCallbackNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Apply a measure callback rule with given arguments.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
virtual bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) = 0;

static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
};

/*! \brief The measure callback with customized methods on the python-side. */
class PyMeasureCallbackNode : public MeasureCallbackNode {
public:
/*!
* \brief Apply a measure callback to the given schedule.
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
*/
using FApply =
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results)>;
/*!
* \brief Get the measure callback function as string with name.
* \return The string of the measure callback function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `Apply` funcion. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_apply` is not visited
// `f_as_string` is not visited
}

bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
};

/*!
* \brief Managed reference to MeasureCallbackNode
* \sa MeasureCallbackNode
*/
class MeasureCallback : public runtime::ObjectRef {
public:
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
Loading

0 comments on commit a2b4a94

Please sign in to comment.