Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Extract task weights during task extraction #10810

Merged
merged 2 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ class ExtractedTaskNode : public runtime::Object {
Target target;
/*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
Array<IRModule> dispatched;
/*! \brief Weight of the task */
int weight;

void VisitAttrs(AttrVisitor* v) {
v->Visit("task_name", &task_name);
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("dispatched", &dispatched);
v->Visit("weight", &weight);
}

static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
Expand All @@ -66,8 +69,10 @@ class ExtractedTask : public runtime::ObjectRef {
* \param mod The high-level IR
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
* \param weight The weight of the task
*/
explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode);
explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
int weight);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef,
ExtractedTaskNode);
};

/**************** MetaScheduleContext ****************/
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,30 @@ class ExtractedTask(Object):
Target information
dispatched : List[IRModule]
A list of low-level IRs that the high-level IR could potentially dispatch to
weight : int
The weight of the task
"""

task_name: str
mod: IRModule
dispatched: List[IRModule]
weight: int

def __init__(
self,
task_name: str,
mod: IRModule,
target: Target,
dispatched: List[IRModule],
weight: int,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member
task_name,
mod,
target,
dispatched,
weight,
)


Expand Down Expand Up @@ -239,6 +244,4 @@ def extract_task_from_relay(
config=pass_config,
disabled_pass=disabled_pass,
):
tasks = extract_task_func(mod, target, relay_params)
# Tasks are extracted via post order visit, return the reversed list.
return list(reversed(tasks))
return list(extract_task_func(mod, target, relay_params))
9 changes: 5 additions & 4 deletions src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ bool HasOnlyOneFunction(const IRModule& mod) {
/**************** ExtractedTask ****************/

ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
Array<IRModule> dispatched) {
Array<IRModule> dispatched, int weight) {
ObjectPtr<ExtractedTaskNode> n = make_object<ExtractedTaskNode>();
n->task_name = task_name;
n->mod = mod;
n->target = target;
n->dispatched = dispatched;
n->weight = weight;
data_ = n;
}

Expand Down Expand Up @@ -161,9 +162,9 @@ TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);

TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
.set_body_typed([](String task_name, IRModule mod, Target target,
Array<IRModule> dispatched) -> ExtractedTask {
return ExtractedTask(task_name, mod, target, dispatched);
.set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
int weight) -> ExtractedTask {
return ExtractedTask(task_name, mod, target, dispatched, weight);
});
TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope")
.set_body_typed(MetaScheduleContextInternal::EnterScope);
Expand Down
44 changes: 27 additions & 17 deletions src/relay/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,40 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
transform::Sequential seq(pass_seqs);
auto opt_mod = seq(std::move(mod));

Array<ExtractedTask> tasks;
std::unordered_set<tec::CCacheKey> cache;
std::unordered_map<std::string, int> name_map;
std::vector<ExtractedTask> tasks;
std::unordered_map<tec::CCacheKey, ExtractedTask> cache;

PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache, &name_map](const Expr& exp) {
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
return;
}
tec::CCacheKey cache_key(relay_func, target);
if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) {
Array<te::Tensor> inputs_outputs;
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
auto prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
auto task_name = tec::GetUniqueName(fused_name, &name_map);
tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod}));
cache.insert(cache_key);
auto it = cache.find(cache_key);
if (it != cache.end()) {
it->second->weight += 1;
return;
}
Array<te::Tensor> inputs_outputs;
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
auto prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
}
});

// Tasks are extracted via post order visit, return the reversed list.
std::reverse(tasks.begin(), tasks.end());
std::unordered_map<std::string, int> name_map;
for (ExtractedTask task : tasks) {
task->task_name = tec::GetUniqueName(task->task_name, &name_map);
}
return tasks;
}

Expand Down
106 changes: 100 additions & 6 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,28 @@
# under the License.
import sys
from typing import List
import numpy as np

import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import relay
from tvm import meta_schedule as ms
from tvm import relay
from tvm.ir.module import IRModule
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.integration import (
ApplyHistoryBest,
ExtractedTask,
MetaScheduleContext,
)
from tvm.meta_schedule.testing import DummyDatabase
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune import Parse, extract_task_from_relay
from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
from tvm.meta_schedule.testing import DummyDatabase
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune import extract_task_from_relay, Parse

# pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name

Expand Down Expand Up @@ -103,11 +103,105 @@ def test_meta_schedule_integration_extract_from_resnet():
]
]

assert len(extracted_tasks) == 20
assert len(extracted_tasks) == len(expected_task_names)
for t in extracted_tasks:
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_meta_schedule_integration_extract_from_bert_base():
expected = {
"fused_nn_dense_2": (
12,
[[64, 3072], [768, 3072], [64, 768]],
),
"fused_nn_dense": (
48,
[[64, 768], [768, 768], [64, 768]],
),
"fused_nn_dense_1": (
12,
[[64, 768], [3072, 768], [64, 3072]],
),
"fused_subtract_add_sqrt_divide_multiply_add": (
25,
[[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
),
"fused_nn_batch_matmul": (
24,
[[12, 64, 64], [12, 64, 64], [12, 64, 64]],
),
"fused_reshape_add_add": (
24,
[[64, 768], [768], [1, 64, 768], [1, 64, 768]],
),
"fused_variance": (
25,
[[1, 64, 768], [1, 64, 1], [1, 64, 1]],
),
"fused_mean": (
25,
[[1, 64, 768], [1, 64, 1]],
),
"fused_reshape_add_reshape_transpose_reshape": (
12,
[[64, 768], [768], [12, 64, 64]],
),
"fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": (
12,
[[64, 3072], [3072], [64, 3072]],
),
"fused_nn_fast_softmax": (
12,
[[1, 12, 64, 64], [1, 12, 64, 64]],
),
"fused_reshape_add_reshape_transpose_reshape_1": (
24,
[[64, 768], [768], [12, 64, 64]],
),
"fused_reshape_divide_add": (
12,
[[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]],
),
"fused_reshape_transpose_reshape": (
12,
[[12, 64, 64], [64, 768]],
),
"fused_nn_dense_add_fast_tanh": (
1,
[[1, 768], [768, 768], [1, 768], [1, 768]],
),
"fused_cast_take_add": (
1,
[[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]],
),
"fused_take": (
1,
[[1, 64, 768], [1, 768]],
),
"fused_reshape": (
12,
[[1, 12, 64, 64], [12, 64, 64]],
),
"fused_reshape_1": (
24,
[[1, 64, 768], [64, 768]],
),
}
mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
assert len(extracted_tasks) == len(expected)
for t in extracted_tasks:
prim_func = None
for _, v in t.dispatched[0].functions.items():
prim_func = v
shape = [[int(x) for x in prim_func.buffer_map[b].shape] for b in prim_func.params]
assert t.task_name in expected
expected_weight, expected_shape = expected[t.task_name]
assert expected_weight == t.weight, t.task_name
assert expected_shape == shape, t.task_name


@requires_torch
def test_meta_schedule_integration_apply_history_best():
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
Expand Down