diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 0b8705aafea9f..eb795fa3933a3 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -17,7 +17,7 @@ """MetaSchedule-Relay integration""" from contextlib import contextmanager from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set # isort: off from typing_extensions import Literal @@ -120,6 +120,7 @@ def extract_tasks( ), executor: Optional["relay.backend.Executor"] = None, module_equality: str = "structural", + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -171,6 +172,7 @@ def extract_tasks( with transform.PassContext( opt_level=opt_level, config=pass_config, + disabled_pass=disabled_pass, ): return list(_extract_task(mod, target, params, module_equality)) @@ -250,6 +252,7 @@ def tune_relay( seed: Optional[int] = None, module_equality: str = "structural", num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> Database: """Tune a Relay program. @@ -306,7 +309,9 @@ def tune_relay( The database that contains the tuning records """ tasks, task_weights = extracted_tasks_to_tune_contexts( - extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality), + extracted_tasks=extract_tasks( + mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass + ), work_dir=work_dir, space=space, strategy=strategy, @@ -345,6 +350,7 @@ def compile_relay( } ), executor: Optional["relay.backend.Executor"] = None, + disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ): """Compile a relay program with a MetaSchedule database. @@ -387,6 +393,7 @@ def compile_relay( with transform.PassContext( opt_level=opt_level, config=pass_config, + disabled_pass=disabled_pass, ): if backend == "graph": return relay.build(mod, target=target, params=params, executor=executor)