[MetaSchedule] Extract task weights during task extraction (#10810)
* [MetaSchedule] Extract task weights on task extraction

* Update test_meta_schedule_integration.py
junrushao authored Mar 29, 2022
1 parent 5306ffa commit c2488ac
Showing 5 changed files with 145 additions and 32 deletions.
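In short: task extraction previously deduplicated identical primitive functions and silently dropped their multiplicity; after this commit every `ExtractedTask` carries a `weight` field recording how many times its fused function occurs in the model. A minimal sketch of the resulting API, reusing the workload helpers exercised in the tests below:

```python
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.relay_workload import get_network

# Extract tuning tasks from ResNet-18; each task now reports its multiplicity.
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
for task in extracted_tasks:
    # `weight` counts how many times this fused function appears in the model,
    # e.g. a workload shared by several layers gets weight > 1.
    print(task.task_name, task.weight)
```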
9 changes: 7 additions & 2 deletions include/tvm/meta_schedule/integration.h
@@ -43,12 +43,15 @@ class ExtractedTaskNode : public runtime::Object {
   Target target;
   /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */
   Array<IRModule> dispatched;
+  /*! \brief Weight of the task */
+  int weight;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("task_name", &task_name);
     v->Visit("mod", &mod);
     v->Visit("target", &target);
     v->Visit("dispatched", &dispatched);
+    v->Visit("weight", &weight);
   }
 
   static constexpr const char* _type_key = "meta_schedule.ExtractedTask";
@@ -66,8 +69,10 @@ class ExtractedTask : public runtime::ObjectRef {
   * \brief The high-level IR
   * \brief A list of low-level IRs that the high-level IR could potentially dispatch to
   */
-  explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched);
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode);
+  explicit ExtractedTask(String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
+                         int weight);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef,
+                                                    ExtractedTaskNode);
 };
 
 /**************** MetaScheduleContext ****************/
9 changes: 6 additions & 3 deletions python/tvm/meta_schedule/integration.py
@@ -45,25 +45,30 @@ class ExtractedTask(Object):
         Target information
     dispatched : List[IRModule]
         A list of low-level IRs that the high-level IR could potentially dispatch to
+    weight : int
+        The weight of the task
     """
 
     task_name: str
     mod: IRModule
     dispatched: List[IRModule]
+    weight: int
 
     def __init__(
         self,
         task_name: str,
         mod: IRModule,
         target: Target,
         dispatched: List[IRModule],
+        weight: int,
     ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.ExtractedTask,  # type: ignore # pylint: disable=no-member
             task_name,
             mod,
             target,
             dispatched,
+            weight,
         )


@@ -239,6 +244,4 @@ def extract_task_from_relay(
         config=pass_config,
         disabled_pass=disabled_pass,
     ):
-        tasks = extract_task_func(mod, target, relay_params)
-        # Tasks are extracted via post order visit, return the reversed list.
-        return list(reversed(tasks))
+        return list(extract_task_func(mod, target, relay_params))
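For illustration, the widened Python constructor can be exercised directly. A minimal sketch, assuming a tiny TE-derived IRModule is an acceptable stand-in for the Relay module and its dispatch candidate (the workload and task name here are made up, not part of this commit):

```python
import tvm
from tvm import te
from tvm.meta_schedule.integration import ExtractedTask
from tvm.target import Target

# A trivial PrimFunc to stand in for both `mod` and `dispatched`.
A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
demo_mod = tvm.IRModule({"main": te.create_prim_func([A, B])})

task = ExtractedTask(
    task_name="demo_task",  # hypothetical name, for illustration only
    mod=demo_mod,
    target=Target("llvm"),
    dispatched=[demo_mod],
    weight=3,               # the new field: multiplicity in the model
)
assert task.weight == 3
```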
9 changes: 5 additions & 4 deletions src/meta_schedule/integration.cc
@@ -62,12 +62,13 @@ bool HasOnlyOneFunction(const IRModule& mod) {
 /**************** ExtractedTask ****************/
 
 ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
-                             Array<IRModule> dispatched) {
+                             Array<IRModule> dispatched, int weight) {
   ObjectPtr<ExtractedTaskNode> n = make_object<ExtractedTaskNode>();
   n->task_name = task_name;
   n->mod = mod;
   n->target = target;
   n->dispatched = dispatched;
+  n->weight = weight;
   data_ = n;
 }

@@ -161,9 +162,9 @@ TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
 TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
 
 TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
-    .set_body_typed([](String task_name, IRModule mod, Target target,
-                       Array<IRModule> dispatched) -> ExtractedTask {
-      return ExtractedTask(task_name, mod, target, dispatched);
+    .set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
+                       int weight) -> ExtractedTask {
+      return ExtractedTask(task_name, mod, target, dispatched, weight);
     });
 TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope")
     .set_body_typed(MetaScheduleContextInternal::EnterScope);
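The `TVM_REGISTER_GLOBAL` entry above is what the Python `__init__` in `integration.py` reaches through `_ffi_api.ExtractedTask`; the same packed function can also be looked up by name, which is a quick way to confirm the binding is live (a sketch, not part of this commit):

```python
import tvm

# Resolve the constructor registered under "meta_schedule.ExtractedTask".
# After this commit it expects five arguments, ending with the int weight.
ctor = tvm.get_global_func("meta_schedule.ExtractedTask")
assert ctor is not None
```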
44 changes: 27 additions & 17 deletions src/relay/backend/task_extraction.cc
@@ -47,30 +47,40 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
   transform::Sequential seq(pass_seqs);
   auto opt_mod = seq(std::move(mod));
 
-  Array<ExtractedTask> tasks;
-  std::unordered_set<tec::CCacheKey> cache;
-  std::unordered_map<std::string, int> name_map;
+  std::vector<ExtractedTask> tasks;
+  std::unordered_map<tec::CCacheKey, ExtractedTask> cache;
 
-  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache, &name_map](const Expr& exp) {
+  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) {
     if (exp->IsInstance<FunctionNode>()) {
       Function relay_func = Downcast<Function>(exp);
+      if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
+        return;
+      }
       tec::CCacheKey cache_key(relay_func, target);
-      if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) {
-        Array<te::Tensor> inputs_outputs;
-        std::string fused_name;
-        std::tie(inputs_outputs, fused_name) =
-            tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
-        auto prim_func = tir::CreatePrimFunc(inputs_outputs);
-        GlobalVar prim_fn_var(fused_name);
-        IRModule relay_mod({{prim_fn_var, relay_func}});
-        IRModule tir_mod({{prim_fn_var, prim_func}});
-        auto task_name = tec::GetUniqueName(fused_name, &name_map);
-        tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod}));
-        cache.insert(cache_key);
-      }
+      auto it = cache.find(cache_key);
+      if (it != cache.end()) {
+        it->second->weight += 1;
+        return;
+      }
+      Array<te::Tensor> inputs_outputs;
+      std::string fused_name;
+      std::tie(inputs_outputs, fused_name) =
+          tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
+      auto prim_func = tir::CreatePrimFunc(inputs_outputs);
+      GlobalVar prim_fn_var(fused_name);
+      IRModule relay_mod({{prim_fn_var, relay_func}});
+      IRModule tir_mod({{prim_fn_var, prim_func}});
+      ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
+      tasks.push_back(extracted_task);
+      cache.emplace(cache_key, extracted_task);
     }
   });
 
   // Tasks are extracted via post order visit, return the reversed list.
   std::reverse(tasks.begin(), tasks.end());
+  std::unordered_map<std::string, int> name_map;
+  for (ExtractedTask task : tasks) {
+    task->task_name = tec::GetUniqueName(task->task_name, &name_map);
+  }
   return tasks;
 }

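The rewritten loop is hash-consing with a counter: the first time a (function, target) cache key is seen, a task is created with weight 1; every later hit increments the weight of the cached task in place (which is why the object ref became mutable above), and unique names are assigned only after the final reversal. A toy Python model of that bookkeeping, with illustrative names that are not TVM API:

```python
from typing import Dict, List, Tuple

def extract_with_weights(visited_funcs: List[str]) -> List[Tuple[str, int]]:
    """Toy model of ExtractTask: strings stand in for CCacheKey."""
    tasks: List[List] = []       # [name, weight] records, mutated in place
    cache: Dict[str, List] = {}  # cache key -> shared task record
    for func in visited_funcs:
        if func in cache:
            cache[func][1] += 1  # repeated function: bump the weight
            continue
        record = [func, 1]       # first occurrence: weight starts at 1
        tasks.append(record)
        cache[func] = record
    tasks.reverse()              # visit is post-order, so reverse at the end
    return [tuple(t) for t in tasks]

# Two occurrences of "fused_nn_dense" collapse into one task with weight 2.
assert extract_with_weights(["fused_nn_dense", "fused_softmax", "fused_nn_dense"]) == [
    ("fused_softmax", 1),
    ("fused_nn_dense", 2),
]
```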
106 changes: 100 additions & 6 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -16,28 +16,28 @@
 # under the License.
 import sys
 from typing import List
-import numpy as np
 
+import numpy as np
 import pytest
 import tvm
 import tvm.testing
-from tvm import relay
 from tvm import meta_schedule as ms
+from tvm import relay
 from tvm.ir.module import IRModule
 from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
 from tvm.meta_schedule.integration import (
     ApplyHistoryBest,
     ExtractedTask,
     MetaScheduleContext,
 )
+from tvm.meta_schedule.testing import DummyDatabase
+from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
+from tvm.meta_schedule.tune import Parse, extract_task_from_relay
 from tvm.meta_schedule.utils import derived_object
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir import Schedule
-from tvm.meta_schedule.testing import DummyDatabase
-from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
-from tvm.meta_schedule.tune import extract_task_from_relay, Parse
 
 # pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
 
@@ -103,11 +103,105 @@ def test_meta_schedule_integration_extract_from_resnet():
         ]
     ]
 
-    assert len(extracted_tasks) == 20
+    assert len(extracted_tasks) == len(expected_task_names)
     for t in extracted_tasks:
         assert t.task_name in expected_task_names, t.task_name
 
 
+@requires_torch
+def test_meta_schedule_integration_extract_from_bert_base():
+    expected = {
+        "fused_nn_dense_2": (
+            12,
+            [[64, 3072], [768, 3072], [64, 768]],
+        ),
+        "fused_nn_dense": (
+            48,
+            [[64, 768], [768, 768], [64, 768]],
+        ),
+        "fused_nn_dense_1": (
+            12,
+            [[64, 768], [3072, 768], [64, 3072]],
+        ),
+        "fused_subtract_add_sqrt_divide_multiply_add": (
+            25,
+            [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
+        ),
+        "fused_nn_batch_matmul": (
+            24,
+            [[12, 64, 64], [12, 64, 64], [12, 64, 64]],
+        ),
+        "fused_reshape_add_add": (
+            24,
+            [[64, 768], [768], [1, 64, 768], [1, 64, 768]],
+        ),
+        "fused_variance": (
+            25,
+            [[1, 64, 768], [1, 64, 1], [1, 64, 1]],
+        ),
+        "fused_mean": (
+            25,
+            [[1, 64, 768], [1, 64, 1]],
+        ),
+        "fused_reshape_add_reshape_transpose_reshape": (
+            12,
+            [[64, 768], [768], [12, 64, 64]],
+        ),
+        "fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": (
+            12,
+            [[64, 3072], [3072], [64, 3072]],
+        ),
+        "fused_nn_fast_softmax": (
+            12,
+            [[1, 12, 64, 64], [1, 12, 64, 64]],
+        ),
+        "fused_reshape_add_reshape_transpose_reshape_1": (
+            24,
+            [[64, 768], [768], [12, 64, 64]],
+        ),
+        "fused_reshape_divide_add": (
+            12,
+            [[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]],
+        ),
+        "fused_reshape_transpose_reshape": (
+            12,
+            [[12, 64, 64], [64, 768]],
+        ),
+        "fused_nn_dense_add_fast_tanh": (
+            1,
+            [[1, 768], [768, 768], [1, 768], [1, 768]],
+        ),
+        "fused_cast_take_add": (
+            1,
+            [[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]],
+        ),
+        "fused_take": (
+            1,
+            [[1, 64, 768], [1, 768]],
+        ),
+        "fused_reshape": (
+            12,
+            [[1, 12, 64, 64], [12, 64, 64]],
+        ),
+        "fused_reshape_1": (
+            24,
+            [[1, 64, 768], [64, 768]],
+        ),
+    }
+    mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
+    extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
+    assert len(extracted_tasks) == len(expected)
+    for t in extracted_tasks:
+        prim_func = None
+        for _, v in t.dispatched[0].functions.items():
+            prim_func = v
+        shape = [[int(x) for x in prim_func.buffer_map[b].shape] for b in prim_func.params]
+        assert t.task_name in expected
+        expected_weight, expected_shape = expected[t.task_name]
+        assert expected_weight == t.weight, t.task_name
+        assert expected_shape == shape, t.task_name
+
+
 @requires_torch
 def test_meta_schedule_integration_apply_history_best():
     mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
