Skip to content

Commit

Permalink
[MetaSchedule] Add "disabled_pass" option in tuning API
Browse files Browse the repository at this point in the history
Now there is no way to disable passes in MetaShedule tuner. This commit adds
new parameter "disabled_pass" in tuning API (tune_relay/compile_relay).
It can be used for different experiments and non default behavoir.
  • Loading branch information
ibsidorenko committed Dec 26, 2022
1 parent e268014 commit 22522d8
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""MetaSchedule-Relay integration"""
from contextlib import contextmanager
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set

# isort: off
from typing_extensions import Literal
Expand Down Expand Up @@ -120,6 +120,7 @@ def extract_tasks(
),
executor: Optional["relay.backend.Executor"] = None,
module_equality: str = "structural",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
Expand Down Expand Up @@ -171,6 +172,7 @@ def extract_tasks(
with transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
return list(_extract_task(mod, target, params, module_equality))

Expand Down Expand Up @@ -250,6 +252,7 @@ def tune_relay(
seed: Optional[int] = None,
module_equality: str = "structural",
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> Database:
"""Tune a Relay program.
Expand Down Expand Up @@ -306,7 +309,9 @@ def tune_relay(
The database that contains the tuning records
"""
tasks, task_weights = extracted_tasks_to_tune_contexts(
extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality),
extracted_tasks=extract_tasks(
mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass
),
work_dir=work_dir,
space=space,
strategy=strategy,
Expand Down Expand Up @@ -345,6 +350,7 @@ def compile_relay(
}
),
executor: Optional["relay.backend.Executor"] = None,
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
):
"""Compile a relay program with a MetaSchedule database.
Expand Down Expand Up @@ -387,6 +393,7 @@ def compile_relay(
with transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
if backend == "graph":
return relay.build(mod, target=target, params=params, executor=executor)
Expand Down

0 comments on commit 22522d8

Please sign in to comment.