From 108ced6f2f8faefcb319de2d25a637af15a776d8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 28 Mar 2022 22:49:49 -0700 Subject: [PATCH 1/2] [MetaSchedule] Extract task weights on task extraction --- include/tvm/meta_schedule/integration.h | 9 +- python/tvm/meta_schedule/integration.py | 9 +- src/meta_schedule/integration.cc | 9 +- src/relay/backend/task_extraction.cc | 44 ++++--- .../test_meta_schedule_integration.py | 110 +++++++++++++++++- 5 files changed, 149 insertions(+), 32 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index 56d8d379df93..b231913f2f9b 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -43,12 +43,15 @@ class ExtractedTaskNode : public runtime::Object { Target target; /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ Array dispatched; + /*! \brief Weight of the task */ + int weight; void VisitAttrs(AttrVisitor* v) { v->Visit("task_name", &task_name); v->Visit("mod", &mod); v->Visit("target", &target); v->Visit("dispatched", &dispatched); + v->Visit("weight", &weight); } static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; @@ -66,8 +69,10 @@ class ExtractedTask : public runtime::ObjectRef { * \brief The high-level IR * \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ - explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode); + explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched, + int weight); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, + ExtractedTaskNode); }; /**************** MetaScheduleContext ****************/ diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 3c08b21f9511..db6771fecafc 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -45,11 +45,14 @@ class ExtractedTask(Object): Target information dispatched : List[IRModule] A list of low-level IRs that the high-level IR could potentially dispatch to + weight : int + The weight of the task """ task_name: str mod: IRModule dispatched: List[IRModule] + weight: int def __init__( self, @@ -57,6 +60,7 @@ def __init__( mod: IRModule, target: Target, dispatched: List[IRModule], + weight: int, ) -> None: self.__init_handle_by_constructor__( _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member @@ -64,6 +68,7 @@ def __init__( mod, target, dispatched, + weight, ) @@ -239,6 +244,4 @@ def extract_task_from_relay( config=pass_config, disabled_pass=disabled_pass, ): - tasks = extract_task_func(mod, target, relay_params) - # Tasks are extracted via post order visit, return the reversed list. - return list(reversed(tasks)) + return list(extract_task_func(mod, target, relay_params)) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index f05e07e0f1c1..35c3baf237a4 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -62,12 +62,13 @@ bool HasOnlyOneFunction(const IRModule& mod) { /**************** ExtractedTask ****************/ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, - Array dispatched) { + Array dispatched, int weight) { ObjectPtr n = make_object(); n->task_name = task_name; n->mod = mod; n->target = target; n->dispatched = dispatched; + n->weight = weight; data_ = n; } @@ -161,9 +162,9 @@ TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode); TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") - .set_body_typed([](String task_name, IRModule mod, Target target, - Array dispatched) -> ExtractedTask { - return ExtractedTask(task_name, mod, target, dispatched); + .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, + int weight) -> ExtractedTask { + return ExtractedTask(task_name, mod, target, dispatched, weight); }); TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope") .set_body_typed(MetaScheduleContextInternal::EnterScope); diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 898e76b81b98..a787f1915099 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -47,30 +47,40 @@ Array ExtractTask(IRModule mod, Target target, transform::Sequential seq(pass_seqs); auto opt_mod = seq(std::move(mod)); - Array tasks; - std::unordered_set cache; - std::unordered_map name_map; + std::vector tasks; + std::unordered_map cache; - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache, &name_map](const Expr& exp) { + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); + if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { + return; + } tec::CCacheKey cache_key(relay_func, target); - if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache.find(cache_key) == cache.end()) { - Array inputs_outputs; - std::string fused_name; - std::tie(inputs_outputs, fused_name) = - tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - auto prim_func = tir::CreatePrimFunc(inputs_outputs); - GlobalVar prim_fn_var(fused_name); - IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, prim_func}}); - auto task_name = tec::GetUniqueName(fused_name, &name_map); - tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod})); - cache.insert(cache_key); + auto it = cache.find(cache_key); + if (it != cache.end()) { + it->second->weight += 1; + return; } + Array inputs_outputs; + std::string fused_name; + std::tie(inputs_outputs, fused_name) = + tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); + auto prim_func = tir::CreatePrimFunc(inputs_outputs); + GlobalVar prim_fn_var(fused_name); + IRModule relay_mod({{prim_fn_var, relay_func}}); + IRModule tir_mod({{prim_fn_var, prim_func}}); + ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1); + tasks.push_back(extracted_task); + cache.emplace(cache_key, extracted_task); } }); - + // Tasks are extracted via post order visit, return the reversed list. + std::reverse(tasks.begin(), tasks.end()); + std::unordered_map name_map; + for (ExtractedTask task : tasks) { + task->task_name = tec::GetUniqueName(task->task_name, &name_map); + } return tasks; } diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index d70c5ab1dc0e..b30995cda2cd 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -16,13 +16,13 @@ # under the License. import sys from typing import List -import numpy as np +import numpy as np import pytest import tvm import tvm.testing -from tvm import relay from tvm import meta_schedule as ms +from tvm import relay from tvm.ir.module import IRModule from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload from tvm.meta_schedule.integration import ( @@ -30,14 +30,14 @@ ExtractedTask, MetaScheduleContext, ) +from tvm.meta_schedule.testing import DummyDatabase from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base +from tvm.meta_schedule.tune import Parse, extract_task_from_relay from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule -from tvm.meta_schedule.testing import DummyDatabase -from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base -from tvm.meta_schedule.tune import extract_task_from_relay, Parse # pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name @@ -103,11 +103,109 @@ def test_meta_schedule_integration_extract_from_resnet(): ] ] - assert len(extracted_tasks) == 20 + assert len(extracted_tasks) == len(expected_task_names) for t in extracted_tasks: assert t.task_name in expected_task_names, t.task_name +@requires_torch +def test_meta_schedule_integration_extract_from_bert_base(): + expected = { + "fused_nn_dense_2": ( + 12, + [[64, 3072], [768, 3072], [64, 768]], + ), + "fused_nn_dense": ( + 48, + [[64, 768], [768, 768], [64, 768]], + ), + "fused_nn_dense_1": ( + 12, + [[64, 768], [3072, 768], [64, 3072]], + ), + "fused_subtract_add_sqrt_divide_multiply_add": ( + 25, + [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]], + ), + "fused_nn_batch_matmul": ( + 24, + [[12, 64, 64], [12, 64, 64], [12, 64, 64]], + ), + "fused_reshape_add_add": ( + 24, + [[64, 768], [768], [1, 64, 768], [1, 64, 768]], + ), + "fused_variance": ( + 25, + [[1, 64, 768], [1, 64, 1], [1, 64, 1]], + ), + "fused_mean": ( + 25, + [[1, 64, 768], [1, 64, 1]], + ), + "fused_reshape_add_reshape_transpose_reshape": ( + 12, + [[64, 768], [768], [12, 64, 64]], + ), + "fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": ( + 12, + [[64, 3072], [3072], [64, 3072]], + ), + "fused_nn_fast_softmax": ( + 12, + [[1, 12, 64, 64], [1, 12, 64, 64]], + ), + "fused_reshape_add_reshape_transpose_reshape_1": ( + 24, + [[64, 768], [768], [12, 64, 64]], + ), + "fused_reshape_divide_add": ( + 12, + [[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]], + ), + "fused_reshape_transpose_reshape": ( + 12, + [[12, 64, 64], [64, 768]], + ), + "fused_nn_dense_add_fast_tanh": ( + 1, + [[1, 768], [768, 768], [1, 768], [1, 768]], + ), + "fused_cast_take_add": ( + 1, + [[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]], + ), + "fused_take": ( + 1, + [[1, 64, 768], [1, 768]], + ), + "fused_reshape": ( + 12, + [[1, 12, 64, 64], [12, 64, 64]], + ), + "fused_reshape_1": ( + 24, + [[1, 64, 768], [64, 768]], + ), + } + mod, params, _ = get_network( + name="bert_base", + input_shape=[1, 64], + cache_dir="/root/cache-workloads", + ) + extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) + assert len(extracted_tasks) == len(expected) + for t in extracted_tasks: + prim_func = None + for _, v in t.dispatched[0].functions.items(): + prim_func = v + shape = [[int(x) for x in prim_func.buffer_map[b].shape] for b in prim_func.params] + assert t.task_name in expected + expected_weight, expected_shape = expected[t.task_name] + assert expected_weight == t.weight, t.task_name + assert expected_shape == shape, t.task_name + + @requires_torch def test_meta_schedule_integration_apply_history_best(): mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) From 96b2df3ebc83456797549b69dc57720a53473553 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 28 Mar 2022 23:45:20 -0700 Subject: [PATCH 2/2] Update test_meta_schedule_integration.py --- tests/python/unittest/test_meta_schedule_integration.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index b30995cda2cd..1bbaf35ad280 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -188,11 +188,7 @@ def test_meta_schedule_integration_extract_from_bert_base(): [[1, 64, 768], [64, 768]], ), } - mod, params, _ = get_network( - name="bert_base", - input_shape=[1, 64], - cache_dir="/root/cache-workloads", - ) + mod, params, _ = get_network(name="bert_base", input_shape=[1, 64]) extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) assert len(extracted_tasks) == len(expected) for t in extracted_tasks: