Skip to content

Commit

Permalink
Clean up task extraction (tlc-pack#92)
Browse files Browse the repository at this point in the history
* Clean up taske extraction

* black
  • Loading branch information
masahi authored and junrushao committed Feb 5, 2023
1 parent a34ad24 commit a2c27c8
Showing 1 changed file with 12 additions and 49 deletions.
61 changes: 12 additions & 49 deletions python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tvm.tir import PrimFunc
from tvm.relax.expr import Function as RelaxFunc
from tvm.relax.utils import tir_partitioner
from tvm.relax.ty import DynTensorType

from . import _ffi_api
from .database import Database
Expand Down Expand Up @@ -248,14 +247,7 @@ def extract_task_from_relay(
return list(reversed(tasks))


def extract_task_from_relax(
mod: Union[IRModule, RelaxFunc],
target: Target,
*,
opt_level: int = 3,
pass_config: Dict[str, DynTensorType] = {},
disabled_pass: List[str] = [],
) -> List[ExtractedTask]:
def extract_task_from_relax(mod: Union[IRModule, RelaxFunc], target: Target) -> List[ExtractedTask]:
"""Extract tuning tasks from a relax program.
Parameters
Expand All @@ -264,53 +256,24 @@ def extract_task_from_relax(
The module or function to tune
target : tvm.target.Target
The compilation target
opt_level : int
The optimization level of the compiler
pass_config : Dict[str, DynTensorType]
The pass config of the compiler
disabled_pass : List[str]
The list of disabled passes of the compiler
Returns
-------
tasks: List[ExtractedTask]
The tasks extracted from this network
The tasks extracted from this module
"""

@contextmanager
def _autotvm_silencer():
from tvm import autotvm # pylint: disable=import-outside-toplevel

silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True
try:
yield
finally:
autotvm.GLOBAL_SCOPE.silent = silent

def _thread_run(func: Callable[[], None]) -> None:
import threading # pylint: disable=import-outside-toplevel

thread = threading.Thread(target=func)
thread.start()
thread.join()

env = TaskExtraction()
if isinstance(mod, RelaxFunc):
mod = IRModule.from_expr(mod)
if not isinstance(target, Target):
target = Target(target)

def _func():
with env, _autotvm_silencer(), transform.PassContext(
config=pass_config,
disabled_pass=disabled_pass,
opt_level=opt_level,
):
tir_partitions = tir_partitioner(mod)
for tir_mod in tir_partitions:
func_name = tir_mod.get_global_vars()[0].name_hint
MetaScheduleContext.query_inside_with_scope(func_name, tir_mod, target, [tir_mod])

_thread_run(_func)
return env.tasks
tir_partitions = tir_partitioner(mod)

tasks = []
for tir_mod in tir_partitions:
task_name = tir_mod.get_global_vars()[0].name_hint
# The second arg to ExtractedTask is supposed to be a high-level IRModule,
# passing tir_mod as a workaround.
tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod]))

return tasks

0 comments on commit a2c27c8

Please sign in to comment.