From eb7289856f061640e0f57dcd22cefce8c517e4e1 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Tue, 12 Mar 2024 06:32:19 +0800 Subject: [PATCH] change register --- gallery/how_to/work_with_msc/using_tools.py | 11 +- .../contrib/msc/core/gym/agent/base_agent.py | 49 +++-- .../tvm/contrib/msc/core/gym/agent/method.py | 11 +- .../msc/core/gym/agent/search_agent.py | 11 +- .../contrib/msc/core/gym/control/configer.py | 12 +- .../msc/core/gym/control/controller.py | 6 +- .../contrib/msc/core/gym/control/service.py | 46 +++-- .../contrib/msc/core/gym/control/worker.py | 32 ++-- .../msc/core/gym/environment/base_env.py | 105 +++++++--- .../msc/core/gym/environment/method.py | 14 +- .../msc/core/gym/environment/prune_env.py | 51 +++-- .../msc/core/gym/environment/quantize_env.py | 58 ++---- python/tvm/contrib/msc/core/gym/namespace.py | 40 ++++ python/tvm/contrib/msc/core/runtime/hook.py | 4 +- python/tvm/contrib/msc/core/runtime/runner.py | 43 ++++- python/tvm/contrib/msc/core/tools/configer.py | 20 +- .../msc/core/tools/distill/distiller.py | 4 +- .../contrib/msc/core/tools/distill/method.py | 6 +- python/tvm/contrib/msc/core/tools/execute.py | 2 +- .../contrib/msc/core/tools/prune/method.py | 6 +- .../contrib/msc/core/tools/prune/pruner.py | 4 +- .../contrib/msc/core/tools/quantize/method.py | 6 +- .../msc/core/tools/quantize/quantizer.py | 4 +- .../contrib/msc/core/tools/track/configer.py | 15 +- .../contrib/msc/core/tools/track/method.py | 4 +- .../contrib/msc/core/tools/track/tracker.py | 4 +- python/tvm/contrib/msc/core/utils/expr.py | 25 ++- python/tvm/contrib/msc/core/utils/file.py | 26 ++- python/tvm/contrib/msc/core/utils/log.py | 8 + python/tvm/contrib/msc/core/utils/message.py | 2 + python/tvm/contrib/msc/core/utils/register.py | 181 ++++++++---------- .../tensorflow/tools/distill/distiller.py | 5 +- .../tensorflow/tools/prune/pruner.py | 5 +- .../tensorflow/tools/quantize/quantizer.py | 5 +- .../tensorflow/tools/track/tracker.py | 5 +- .../msc/framework/tensorrt/runtime/runner.py | 23 +++ .../tensorrt/tools/distill/distiller.py | 5 +- .../framework/tensorrt/tools/prune/pruner.py | 5 +- .../tensorrt/tools/quantize/method.py | 4 +- .../tensorrt/tools/quantize/quantizer.py | 5 +- .../framework/tensorrt/tools/track/tracker.py | 5 +- .../torch/tools/distill/distiller.py | 5 +- .../framework/torch/tools/distill/method.py | 4 +- .../msc/framework/torch/tools/prune/pruner.py | 5 +- .../framework/torch/tools/quantize/method.py | 4 +- .../torch/tools/quantize/quantizer.py | 5 +- .../framework/torch/tools/track/tracker.py | 5 +- .../msc/framework/tvm/runtime/runner.py | 38 +++- .../framework/tvm/tools/distill/distiller.py | 5 +- .../msc/framework/tvm/tools/prune/pruner.py | 5 +- .../framework/tvm/tools/quantize/method.py | 4 +- .../framework/tvm/tools/quantize/quantizer.py | 5 +- .../msc/framework/tvm/tools/track/tracker.py | 5 +- python/tvm/contrib/msc/pipeline/config.py | 14 +- python/tvm/contrib/msc/pipeline/manager.py | 140 ++++++++++---- python/tvm/contrib/msc/pipeline/wrapper.py | 26 ++- 56 files changed, 727 insertions(+), 420 deletions(-) create mode 100644 python/tvm/contrib/msc/core/gym/namespace.py diff --git a/gallery/how_to/work_with_msc/using_tools.py b/gallery/how_to/work_with_msc/using_tools.py index 3c3f528d959d..28cbc4c198bd 100644 --- a/gallery/how_to/work_with_msc/using_tools.py +++ b/gallery/how_to/work_with_msc/using_tools.py @@ -58,7 +58,10 @@ parser.add_argument("--calibrate_iter", type=int, default=100, help="The iter for calibration") parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train") parser.add_argument("--train_iter", type=int, default=200, help="The iter for train") -parser.add_argument("--train_epoch", type=int, default=5, help="The epoch for train") +parser.add_argument("--train_epoch", type=int, default=100, help="The epoch for train") +parser.add_argument( + "--verbose", type=str, default="info", help="The verbose level, info|debug:1,2,3|critical" +) args = parser.parse_args() @@ -86,7 +89,7 @@ def get_config(calib_loader, train_loader): dataset=dataset, tools=tools, skip_config={"all": "check"}, - verbose="info", + verbose=args.verbose, ) @@ -130,3 +133,7 @@ def _get_train_datas(): model.compile() acc = eval_model(model, testloader, max_iter=args.test_iter) print("Compiled acc: " + str(acc)) + + # export the model + path = model.export() + print("Export model to " + str(path)) diff --git a/python/tvm/contrib/msc/core/gym/agent/base_agent.py b/python/tvm/contrib/msc/core/gym/agent/base_agent.py index 801f3f82b430..919118456fbf 100644 --- a/python/tvm/contrib/msc/core/gym/agent/base_agent.py +++ b/python/tvm/contrib/msc/core/gym/agent/base_agent.py @@ -19,6 +19,7 @@ import copy import logging from typing import Dict, Any, List, Tuple +from tvm.contrib.msc.core.gym.namespace import GYMObject from tvm.contrib.msc.core import utils as msc_utils @@ -37,8 +38,6 @@ class BaseAgent(object): The extra options for the agent. debug_level: int The debug level. - verbose: str - The verbose level. logger: logging.Logger The logger """ @@ -50,7 +49,6 @@ def __init__( executors: dict, options: dict = None, debug_level: int = 0, - verbose: str = None, logger: logging.Logger = None, ): self._name = name @@ -58,15 +56,8 @@ def __init__( self._executors = self._parse_executors(msc_utils.copy_dict(executors)) self._options = options or {} self._debug_level = debug_level - if logger: - self._logger = logger - else: - if not verbose: - verbose = "debug" if debug_level > 0 else "info" - self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("AGENT_LOG")) - self._logger.info( - msc_utils.msg_block("AGENT.SETUP({})".format(self.agent_type()), self.setup()) - ) + self._logger = logger or msc_utils.get_global_logger() + self._logger.info(msc_utils.msg_block(self.agent_mark("SETUP"), self.setup())) def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]: """Parse the executors @@ -85,9 +76,12 @@ def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, di executors = {} for name, raw_config in executors_dict.items(): method_type = ( - raw_config.pop("method_type") if "method_type" in raw_config else "agent.default" + raw_config.pop("method_type") if "method_type" in raw_config else "default" + ) + method_cls = msc_utils.get_registered_gym_method(GYMObject.AGENT, method_type) + assert method_cls, "Can not find method cls for {}:{}".format( + GYMObject.AGENT, method_type ) - method_cls = msc_utils.get_registered_gym_method(method_type) assert "method" in raw_config, "method should be given to find agent method" method_name, method = raw_config.pop("method"), None if hasattr(method_cls, method_name): @@ -244,7 +238,7 @@ def learn(self): The learned rewards. """ - self._logger.debug(msc_utils.msg_block("AGENT.LEARN", self._knowledge)) + self._logger.debug(msc_utils.msg_block(self.agent_mark("KNOWLEDEG"), self._knowledge)) return self._learn() def _learn(self): @@ -306,9 +300,26 @@ def _evaluate(self, reward: dict) -> float: return self._execute("evaluate", self._baseline, reward) - @classmethod - def agent_type(cls): - return "base" + def agent_mark(self, msg: Any) -> str: + """Mark the message with agent info + + Parameters + ------- + msg: str + The message + Returns + ------- + msg: str + The message with mark. + """ + + return "AGENT({}) {}".format(self.role_type(), msg) -msc_utils.register_gym_agent(BaseAgent) + @classmethod + def role(cls): + return GYMObject.AGENT + + @classmethod + def role_type(cls): + return "base" diff --git a/python/tvm/contrib/msc/core/gym/agent/method.py b/python/tvm/contrib/msc/core/gym/agent/method.py index 988fb23f69d6..af9c3cbe91a9 100644 --- a/python/tvm/contrib/msc/core/gym/agent/method.py +++ b/python/tvm/contrib/msc/core/gym/agent/method.py @@ -18,9 +18,11 @@ """tvm.contrib.msc.core.gym.agent.method""" from typing import Any +from tvm.contrib.msc.core.gym.namespace import GYMObject from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_gym_method class AgentMethod(object): """Default prune method""" @@ -73,8 +75,9 @@ def evaluate_by_thresh(cls, agent: Any, baseline: dict, reward: dict, thresh: fl return reward["reward"] @classmethod - def method_type(cls): - return "agent.default" - + def role(cls): + return GYMObject.AGENT -msc_utils.register_gym_method(AgentMethod) + @classmethod + def method_type(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/gym/agent/search_agent.py b/python/tvm/contrib/msc/core/gym/agent/search_agent.py index 8b9bc176ab47..743c3a1f752c 100644 --- a/python/tvm/contrib/msc/core/gym/agent/search_agent.py +++ b/python/tvm/contrib/msc/core/gym/agent/search_agent.py @@ -37,10 +37,11 @@ def setup(self) -> dict: return super().setup() @classmethod - def agent_type(cls): + def role_type(cls): return "search.base" +@msc_utils.register_gym_object class GridSearchAgent(BaseSearchAgent): """GridSearch agent""" @@ -92,10 +93,11 @@ def _learn(self): return best_actions, best_rewards @classmethod - def agent_type(cls): + def role_type(cls): return "search.grid" +@msc_utils.register_gym_object class BinarySearchAgent(BaseSearchAgent): """BinarySearch agent""" @@ -173,8 +175,5 @@ def _learn(self): return actions, rewards @classmethod - def agent_type(cls): + def role_type(cls): return "search.binary" - - -msc_utils.register_gym_agent(GridSearchAgent) diff --git a/python/tvm/contrib/msc/core/gym/control/configer.py b/python/tvm/contrib/msc/core/gym/control/configer.py index 00cb54cfd39a..89f2f82d179f 100644 --- a/python/tvm/contrib/msc/core/gym/control/configer.py +++ b/python/tvm/contrib/msc/core/gym/control/configer.py @@ -48,6 +48,7 @@ def update(self, raw_config: dict) -> dict: raise NotImplementedError("update is not implemented in BaseConfiger") +@msc_utils.register_gym_configer class DefaultConfiger(BaseConfiger): """Default configer for gym""" @@ -67,10 +68,10 @@ def update(self, raw_config: dict) -> dict: config = msc_utils.copy_dict(raw_config) assert "env" in config and "agent" in config, "env and agent should be given to run gym" - if "env_type" not in config["env"]: - config["env"]["env_type"] = self._stage + ".default" - if "agent_type" not in config["agent"]: - config["agent"]["agent_type"] = "search.grid" + if "role_type" not in config["env"]: + config["env"]["role_type"] = self._stage + ".default" + if "role_type" not in config["agent"]: + config["agent"]["role_type"] = "search.grid" if "executors" not in config["env"]: config["env"]["executors"] = {} # update executors @@ -92,6 +93,3 @@ def update(self, raw_config: dict) -> dict: @classmethod def config_type(cls): return "default" - - -msc_utils.register_gym_configer(DefaultConfiger) diff --git a/python/tvm/contrib/msc/core/gym/control/controller.py b/python/tvm/contrib/msc/core/gym/control/controller.py index 17ca5edb1c0a..c0a6248ce3b6 100644 --- a/python/tvm/contrib/msc/core/gym/control/controller.py +++ b/python/tvm/contrib/msc/core/gym/control/controller.py @@ -17,9 +17,9 @@ """tvm.contrib.msc.core.gym.control.controller""" from typing import Dict, Any +from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction from tvm.contrib.msc.core import utils as msc_utils from .service import MainService, NodeService -from .namespace import GYMObject, GYMAction class BaseController(object): @@ -98,10 +98,8 @@ def create_controller(stage: str, config: dict, extra_config: dict = None): return controller_cls(msc_utils.get_gym_dir(), config) +@msc_utils.register_gym_controller class DefaultController(BaseController): @classmethod def control_type(cls): return "default" - - -msc_utils.register_gym_controller(DefaultController) diff --git a/python/tvm/contrib/msc/core/gym/control/service.py b/python/tvm/contrib/msc/core/gym/control/service.py index f8fbdd31ddf6..06685c020be9 100644 --- a/python/tvm/contrib/msc/core/gym/control/service.py +++ b/python/tvm/contrib/msc/core/gym/control/service.py @@ -25,9 +25,9 @@ import queue import numpy as np +from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction from tvm.contrib.msc.core import utils as msc_utils -from .worker import BaseWorker, WorkerFactory -from .namespace import GYMObject, GYMAction +from .worker import BaseGymWorker, WorkerFactory def _send_message(msg_queue: queue.Queue, header: str, body: dict, header_type: str = "message"): @@ -149,10 +149,8 @@ class BaseService(object): The max seatch iter. record_step: int The record step. - debug_level: int - The debug level verbose: str - The verbose level. + The verbose level """ def __init__( @@ -170,15 +168,13 @@ def __init__( ): self._workspace = workspace tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"] - if not verbose: - verbose = "debug" if debug_level > 0 else "info" + verbose = verbose or "info" + debug_level = int(verbose.split(":")[1]) if verbose.startswith("debug:") else 0 self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG")) - def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]: + def _create_workers(config: dict, obj_type: str) -> List[BaseGymWorker]: if "debug_level" not in config: config["debug_level"] = debug_level - if "verbose" not in config: - config["verbose"] = verbose if "logger" not in config: config["logger"] = self._logger return [ @@ -192,9 +188,7 @@ def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]: self._max_iter = max_iter self._record_step = record_step self._debug_level = debug_level - self._logger.info( - msc_utils.msg_block("SERVICE.SETUP({})".format(self.service_type), self.setup()) - ) + self._logger.info(msc_utils.msg_block(self.service_mark("SETUP"), self.setup())) def setup(self) -> dict: """Setup the tool @@ -242,8 +236,8 @@ def reset(self): self._task_id, self._states = 0, [] self._iter_done = False self._logger.info("SERVICE Reset %d/%d th iter", self._iter_id, self._max_iter) - self.execute(GYMObject.AGENT, GYMAction.RESET) self.execute(GYMObject.ENV, GYMAction.RESET) + self.execute(GYMObject.AGENT, GYMAction.RESET) def learn(self): self.execute(GYMObject.AGENT, GYMAction.LEARN) @@ -387,9 +381,9 @@ def _process_request(self, msg_key: str) -> dict: workers = {w.worker_id: w for w in self._get_workers(obj_type)} requests = self._wait_request(msg_key) if act_type in (GYMAction.INIT, GYMAction.RESET): - mark = "I[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type) + mark = "Iter[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type) else: - mark = "I[{}/{}].T[{}/{}] {}.{}".format( + mark = "Iter[{}/{}] Task[{}/{}] {}.{}".format( self._iter_id, self._max_iter, self._task_id, self._max_task, obj_type, act_type ) requests = {int(k): v for k, v in requests.items()} @@ -400,7 +394,7 @@ def _process_request(self, msg_key: str) -> dict: "requests": {workers[w].name: r for w, r in requests.items()}, "responses": {workers[w].name: r for w, r in responses.items()}, } - self._logger.info(msc_utils.msg_table(mark, info)) + self._logger.info(msc_utils.msg_block(mark, info, symbol="=")) return responses def _process_response(self, msg_key: str, response: dict): @@ -464,7 +458,7 @@ def _from_msg_key(self, msg_key: str) -> Tuple[str, str]: return msg_key.split("-s-") - def _get_workers(self, obj_type: str) -> List[BaseWorker]: + def _get_workers(self, obj_type: str) -> List[BaseGymWorker]: """Get workers according to obj_type Parameters @@ -519,6 +513,22 @@ def _get_world_ids(self, obj_type: str) -> List[int]: return self._agent_world_ids return [] + def service_mark(self, msg: Any) -> str: + """Mark the message with service info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "SERIVCE({}) {}".format(self.service_type, msg) + @property def done(self): return self._done diff --git a/python/tvm/contrib/msc/core/gym/control/worker.py b/python/tvm/contrib/msc/core/gym/control/worker.py index 7ccfb5da38e2..235a228c89f9 100644 --- a/python/tvm/contrib/msc/core/gym/control/worker.py +++ b/python/tvm/contrib/msc/core/gym/control/worker.py @@ -17,11 +17,11 @@ """tvm.contrib.msc.core.gym.control.worker""" from typing import Any +from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction from tvm.contrib.msc.core import utils as msc_utils -from .namespace import GYMObject, GYMAction -class BaseWorker(object): +class BaseGymWorker(object): """Basic worker for gym Parameters @@ -78,7 +78,7 @@ def execute(self, act_type: str, **kwargs) -> Any: The execute result. """ - raise NotImplementedError("execute is not implemented in BaseWorker") + raise NotImplementedError("execute is not implemented in " + str(self.__class__)) @property def obj_type(self): @@ -93,7 +93,7 @@ def worker_id(self): return self._worker_id -class EnvWorker(BaseWorker): +class EnvGymWorker(BaseGymWorker): """Env worker for gym""" def execute(self, act_type: str, **kwargs) -> Any: @@ -136,8 +136,8 @@ def obj_type(self): return GYMObject.ENV -class AgentWorker(BaseWorker): - """Env worker for gym""" +class AgentGymWorker(BaseGymWorker): + """Agent worker for gym""" def execute(self, act_type: str, **kwargs) -> Any: """Execute the worker @@ -182,7 +182,7 @@ class WorkerFactory(object): """The Factory for workers""" @classmethod - def create(cls, name: str, workspace: msc_utils.MSCDirectory, config: dict) -> BaseWorker: + def create(cls, name: str, workspace: msc_utils.MSCDirectory, config: dict) -> BaseGymWorker: """Create worker Parameters @@ -200,17 +200,21 @@ def create(cls, name: str, workspace: msc_utils.MSCDirectory, config: dict) -> B Returns ------- - worker: BaseWorker + worker: BaseGymWorker The create worker. """ + def _get_worker_cls(obj: str): + worker_type = config.pop("role_type") if "role_type" in config else "default" + worker_cls = msc_utils.get_registered_gym_object(obj, worker_type) + assert worker_cls, "Can not find worker class for {}:{}".format(obj, worker_type) + return worker_cls + obj_type, worker_id = name.split(":") if obj_type == GYMObject.ENV: - env_type = config.pop("env_type") if "env_type" in config else "default" - worker_cls = msc_utils.get_registered_gym_env(env_type) - return EnvWorker(name, workspace, int(worker_id), worker_cls, config) + worker_cls = _get_worker_cls(obj_type) + return EnvGymWorker(name, workspace, int(worker_id), worker_cls, config) if obj_type == GYMObject.AGENT: - agent_type = config.pop("agent_type") if "agent_type" in config else "default" - worker_cls = msc_utils.get_registered_gym_agent(agent_type) - return AgentWorker(name, workspace, int(worker_id), worker_cls, config) + worker_cls = _get_worker_cls(obj_type) + return AgentGymWorker(name, workspace, int(worker_id), worker_cls, config) raise TypeError("Worker for {} is not supported".format(obj_type)) diff --git a/python/tvm/contrib/msc/core/gym/environment/base_env.py b/python/tvm/contrib/msc/core/gym/environment/base_env.py index 86f1bff7be89..300b000dcf60 100644 --- a/python/tvm/contrib/msc/core/gym/environment/base_env.py +++ b/python/tvm/contrib/msc/core/gym/environment/base_env.py @@ -18,7 +18,8 @@ import copy import logging -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Tuple, Union +from tvm.contrib.msc.core.gym.namespace import GYMObject from tvm.contrib.msc.core.runtime import BaseRunner from tvm.contrib.msc.core.tools import BaseTool from tvm.contrib.msc.core import utils as msc_utils @@ -43,8 +44,6 @@ class BaseEnv(object): The extra options for the environment. debug_level: int The debug level. - verbose: str - The verbose level. logger: logging.Logger The logger """ @@ -60,27 +59,19 @@ def __init__( options: dict = None, max_tasks: int = -1, debug_level: int = 0, - verbose: str = None, logger: logging.Logger = None, ): self._name = name self._runner = runner self._data_loader = data_loader self._workspace = workspace - self._knowledge = knowledge + self._knowledge = msc_utils.load_dict(knowledge) self._executors = self._parse_executors(msc_utils.copy_dict(executors)) self._options = options or {} self._max_tasks = max_tasks self._debug_level = debug_level - if logger: - self._logger = logger - else: - if not verbose: - verbose = "debug" if debug_level > 0 else "info" - self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("ENV_LOG")) - self._logger.info( - msc_utils.msg_block("ENV.SETUP({})".format(self.env_type()), self.setup()) - ) + self._logger = logger or msc_utils.get_global_logger() + self._logger.info(msc_utils.msg_block(self.env_mark("SETUP"), self.setup())) def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]: """Parse the executors @@ -99,9 +90,12 @@ def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, di executors = {} for name, raw_config in executors_dict.items(): method_type = ( - raw_config.pop("method_type") if "method_type" in raw_config else "env.default" + raw_config.pop("method_type") if "method_type" in raw_config else "default" + ) + method_cls = msc_utils.get_registered_gym_method(GYMObject.ENV, method_type) + assert method_cls, "Can not find method cls for {}:{}".format( + GYMObject.ENV, method_type ) - method_cls = msc_utils.get_registered_gym_method(method_type) assert "method" in raw_config, "method should be given to find enviironment method" method_name, method = raw_config.pop("method"), None if hasattr(method_cls, method_name): @@ -122,6 +116,7 @@ def setup(self) -> dict: """ self._cache_dir = self._workspace.create_dir("Cache") + self._tool = None self._tasks = [] return { "name": self._name, @@ -155,11 +150,11 @@ def init(self) -> Tuple[int, Dict[str, Any]]: self._tasks = self._tasks[: self._max_tasks] # get baseline self._tool.disable() - self._runner.build(self._cache_dir, force_build=True) + self._runner.build(self._cache_dir, force_build=True, disable_tools=[self._tool.tool_type]) baseline = self._reward_runner(-1) self._tool.enable() tasks_info = {"tasks_num": len(self._tasks), "tasks": self._tasks} - self._logger.info(msc_utils.msg_block("ENV.TASKS", tasks_info, width=0)) + self._logger.info(msc_utils.msg_block(self.env_mark("TASKS"), tasks_info)) return len(self._tasks), baseline def _init_tool(self) -> BaseTool: @@ -274,7 +269,7 @@ def summary(self, actions: List[dict], rewards: List[dict]) -> dict: self._logger.info("Env Summary with %d actions, %d rewards", len(actions), len(rewards)) return self._summary(actions, rewards) - def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: """Summary the final plan Parameters @@ -286,12 +281,54 @@ def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: Returns ------- - plan: dict - The final plan. + knowledge: dict| str + The learned knowledge or file. """ raise NotImplementedError("_summary is not implemented in BaseEnv") + def _update_strategy(self, strategy: dict, **kwargs) -> dict: + """Update startegy + + Parameters + ---------- + startegy: dict + The strategy. + kwargs: dict + The kwargs. + + Returns + ------- + strategy: dict + The updated strategy. + """ + + for t_type, method_def in strategy["methods"].items(): + if isinstance(method_def, str): + strategy["methods"][t_type] = {"method_name": method_def, **kwargs} + elif isinstance(method_def, dict): + method_def.update(kwargs) + return strategy + + def _get_strategy(self, action: dict, task_id: int) -> dict: + """Get strategy from task_id + + Parameters + ---------- + action: float + The current action. + task_id: int + The current task id. + + Returns + ------- + strategy: dict + The strategy. + """ + + strategy = msc_utils.copy_dict(self.get_task(task_id)) + return self._update_strategy(strategy, **action) + def get_task(self, task_id: int) -> dict: """Get task according to task_id @@ -363,6 +400,30 @@ def _execute(self, name: str, *args, **kwargs) -> Any: kwargs.update({k: v for k, v in config.items() if k not in kwargs}) return method(self, *args, **kwargs) + def env_mark(self, msg: Any) -> str: + """Mark the message with env info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "ENV({}) {}".format(self.role_type(), msg) + + @property + def tool(self): + return self._tool + + @classmethod + def role(cls): + return GYMObject.ENV + @classmethod - def env_type(cls): + def role_type(cls): return "base" diff --git a/python/tvm/contrib/msc/core/gym/environment/method.py b/python/tvm/contrib/msc/core/gym/environment/method.py index 66fe573d932f..405318c447d9 100644 --- a/python/tvm/contrib/msc/core/gym/environment/method.py +++ b/python/tvm/contrib/msc/core/gym/environment/method.py @@ -20,11 +20,13 @@ from typing import Any, List import numpy as np +from tvm.contrib.msc.core.gym.namespace import GYMObject from tvm.contrib.msc.core.runtime import BaseRunner from tvm.contrib.msc.core.tools import BaseTool from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_gym_method class EnvMethod(object): """Default prune method""" @@ -189,14 +191,16 @@ def action_quantize_scale( """ task = env.get_task(task_id) + plan = env.tool.plan[task["tensor_ids"][0]] return [ - {"scale": task["scale"] * a} + {"scale": plan["scale"] * a} for a in cls.action_linear_space(env, task_id, start, end, step) ] @classmethod - def method_type(cls): - return "env.default" - + def role(cls): + return GYMObject.ENV -msc_utils.register_gym_method(EnvMethod) + @classmethod + def method_type(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/gym/environment/prune_env.py b/python/tvm/contrib/msc/core/gym/environment/prune_env.py index 8f8a53567ef8..eaff86885ec2 100644 --- a/python/tvm/contrib/msc/core/gym/environment/prune_env.py +++ b/python/tvm/contrib/msc/core/gym/environment/prune_env.py @@ -16,12 +16,13 @@ # under the License. """tvm.contrib.msc.core.gym.prune_env""" -from typing import List +from typing import List, Union from tvm.contrib.msc.core.tools import BaseTool, ToolType from tvm.contrib.msc.core import utils as msc_utils from .base_env import BaseEnv +@msc_utils.register_gym_object class PruneEnv(BaseEnv): """Environment for prune""" @@ -29,10 +30,11 @@ def _init_tool(self) -> BaseTool: """Get the main tool""" config = self._runner.get_tool_config(ToolType.PRUNER) - self._meta_strategys = config["strategys"] - for s in self._meta_strategys: - s.update({"density": 1}) - return self._runner.get_tool(ToolType.PRUNER) + self._meta_strategys = msc_utils.copy_dict(config["strategys"]) + self._meta_strategys = [self._update_strategy(s, density=1) for s in self._meta_strategys] + tool = self._runner.get_tool(ToolType.PRUNER) + tool.change_strategys(self._meta_strategys) + return tool def _update_tool(self, action: dict, task_id: int): """Update the tool @@ -46,9 +48,9 @@ def _update_tool(self, action: dict, task_id: int): """ task_strategy = self._get_strategy(action, task_id) - self._tool.plan_by_strategys(self._meta_strategys + [task_strategy]) + self._apply_strategys(self._meta_strategys + [task_strategy]) - def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: """Summary the final plan Parameters @@ -60,36 +62,33 @@ def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: Returns ------- - plan: dict - The final plan. + knowledge: dict| str + The learned knowledge or file. """ - strategys = [self._get_strategy(act, idx) for idx, act in enumerate(actions)] - return self._tool.plan_by_strategys(self._meta_strategys + strategys) + strategys = self._meta_strategys + [ + self._get_strategy(act, idx) for idx, act in enumerate(actions) + ] + return self._apply_strategys(strategys) - def _get_strategy(self, action: dict, task_id: int) -> dict: - """Get strategy from task_id + def _apply_strategys(self, strategys: List[dict]) -> str: + """Apply the strategys Parameters ---------- - action: float - The current action. - task_id: int - The current task id. + strategys: list + The given strategys Returns ------- - strategy: dict - The strategy. + plan_file: str + The plan after strategys applied. """ - strategy = msc_utils.copy_dict(self.get_task(task_id)) - strategy.update(**action) - return strategy + self._tool.change_strategys(strategys) + self._runner.build(self._cache_dir, force_build=True) + return self._runner.make_plan(self._tool.tool_type(), self._data_loader) @classmethod - def env_type(cls): + def role_type(cls): return msc_utils.MSCStage.PRUNE + ".default" - - -msc_utils.register_gym_env(PruneEnv) diff --git a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py index 0a5210b83032..72dee8e5de67 100644 --- a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py +++ b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py @@ -16,22 +16,20 @@ # under the License. """tvm.contrib.msc.core.gym.quantize_env""" -import os -from typing import List +from typing import List, Union from tvm.contrib.msc.core.tools import BaseTool, ToolType from tvm.contrib.msc.core import utils as msc_utils from .base_env import BaseEnv +@msc_utils.register_gym_object class QuantizeEnv(BaseEnv): """Environment for quantize""" def _init_tool(self) -> BaseTool: """Get the main tool""" - plan_file = self._runner.apply_tool(ToolType.QUANTIZER, self._data_loader) - self._meta_plan = msc_utils.load_dict(plan_file) - os.remove(plan_file) + self._runner.make_plan(ToolType.QUANTIZER, self._data_loader) return self._runner.get_tool(ToolType.QUANTIZER) def _update_tool(self, action: dict, task_id: int): @@ -45,11 +43,9 @@ def _update_tool(self, action: dict, task_id: int): The current task id. """ - plan = msc_utils.copy_dict(self._meta_plan) - plan.update(self._get_plan(action, task_id)) - self._tool.set_plan(plan) + self._tool.change_strategys([self._get_strategy(action, task_id)]) - def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: """Summary the final plan Parameters @@ -61,39 +57,21 @@ def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: Returns ------- - plan: dict - The final plan. + knowledge: dict| str + The learned knowledge or file. """ - plan = msc_utils.copy_dict(self._meta_plan) - for idx, act in enumerate(actions): - plan.update(self._get_plan(act, idx)) - return plan - - def _get_plan(self, action: dict, task_id: int) -> dict: - """Get plan from task_id - - Parameters - ---------- - action: float - The current action. - task_id: int - The current task id. - - Returns - ------- - plan: dict - The plan. - """ - - plan = msc_utils.copy_dict(self.get_task(task_id)) - plan.update(**action) - name = plan.pop("name") - return {name: plan} + strategys = self.tool._parse_strategys( + [self._get_strategy(act, idx) for idx, act in enumerate(actions)] + ) + plan = self.tool.plan + for name, info in plan.items(): + if name not in strategys: + continue + info.update(strategys[name].get_executor(msc_utils.MSCStage.QUANTIZE).config) + summary_file = msc_utils.get_cache_dir().relpath("gym_summary.json") + return msc_utils.dump_dict(plan, summary_file) @classmethod - def env_type(cls): + def role_type(cls): return msc_utils.MSCStage.QUANTIZE + ".default" - - -msc_utils.register_gym_env(QuantizeEnv) diff --git a/python/tvm/contrib/msc/core/gym/namespace.py b/python/tvm/contrib/msc/core/gym/namespace.py new file mode 100644 index 000000000000..584316ef3a34 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/namespace.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.gym.namespace""" + + +class GYMObject(object): + """Enum all gym objects""" + + BASE = "base" + ENV = "env" + AGENT = "agent" + SERVICE = "service" + + +class GYMAction(object): + """Enum all gym actions""" + + INIT = "init" + RESET = "reset" + GET_STATE = "get_state" + CHOOSE_ACTION = "choose_action" + STEP = "step" + STORE = "store" + LEARN = "learn" + SUMMARY = "summary" + CLEANUP = "cleanup" diff --git a/python/tvm/contrib/msc/core/runtime/hook.py b/python/tvm/contrib/msc/core/runtime/hook.py index 1229697a63fb..e129d9771b02 100644 --- a/python/tvm/contrib/msc/core/runtime/hook.py +++ b/python/tvm/contrib/msc/core/runtime/hook.py @@ -128,6 +128,7 @@ def name(cls): return "customized" +@msc_utils.register_runner_hook class UpdateWeightsHook(RunnerHook): """Hook for update weights""" @@ -191,6 +192,3 @@ def load_runner_hook(config: dict) -> Any: if hook_cls: return hook_cls(hook_config) return CustomizedHook(hook_ref, hook_config) - - -msc_utils.register_runner_hook(UpdateWeightsHook) diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index c4f4016d148f..e4a9aaa1d39b 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -550,6 +550,22 @@ def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: raise NotImplementedError("export_module is not supported in BaseRunner") + def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the runnable + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The runnable info. + """ + + raise NotImplementedError("export_runnable is not supported in BaseRunner") + def train(self): """Change status to train""" @@ -1216,6 +1232,7 @@ def setup(self) -> dict: """ self._byoc_mod, self._byoc_graph = None, None + self._executable = None return super().setup() def visualize(self, visual_dir: msc_utils.MSCDirectory): @@ -1367,15 +1384,15 @@ def _build_runnable(self, model: Any) -> Any: if self._device == "cpu": target = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.relax.build(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) + self._executable = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(self._executable, tvm.cpu()) elif self._device.startswith("cuda"): target = tvm.target.Target("cuda") with target: model = tvm.tir.transform.DefaultGPUSchedule()(model) with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.relax.build(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) + self._executable = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(self._executable, tvm.cuda()) else: raise NotImplementedError("Unsupported device " + str(self._device)) return runnable @@ -1437,6 +1454,24 @@ def _device_enabled(self, device: str) -> bool: return tvm.cuda(dev_id).exist return False + def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the runnable + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The runnable info. + """ + + export_path = folder.relpath("model.so") + self._executable.export_library(export_path) + return {"model": export_path} + @property def partition_func(self): raise NotImplementedError("partition_func is not implemented for " + str(self.__class__)) diff --git a/python/tvm/contrib/msc/core/tools/configer.py b/python/tvm/contrib/msc/core/tools/configer.py index c9ac6dd876b2..2c6789591721 100644 --- a/python/tvm/contrib/msc/core/tools/configer.py +++ b/python/tvm/contrib/msc/core/tools/configer.py @@ -45,10 +45,7 @@ def config(self, raw_config: dict = None) -> dict: config["tool_config"] = self.update_tool(raw_config) else: config["tool_config"] = self.config_tool() - if self.run_type: - config["run_type"] = self.run_type - if self.apply_once: - config["apply_once"] = self.apply_once + config.update(self.config_apply()) return config def config_tool(self) -> dict: @@ -95,13 +92,16 @@ def config_gym(self, gym_config: Union[dict, str]) -> dict: raise NotImplementedError("config_gym is not implemented in ToolConfiger") - @property - def run_type(self): - return "" + def config_apply(self) -> dict: + """Get the config fro apply - @property - def apply_once(self): - return False + Returns + ------- + config: dict + The apply config. + """ + + return {} @classmethod def tool_type(cls): diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py index 7eee93cbc9e6..39e06b701bbe 100644 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py @@ -271,10 +271,8 @@ def tool_type(cls): return ToolType.DISTILLER +@msc_utils.register_tool class DefaultDistiller(BaseDistiller): @classmethod def tool_style(cls): return "default" - - -msc_utils.register_tool_cls(DefaultDistiller) diff --git a/python/tvm/contrib/msc/core/tools/distill/method.py b/python/tvm/contrib/msc/core/tools/distill/method.py index 0f3fd0fe4824..0fc80d1e30c9 100644 --- a/python/tvm/contrib/msc/core/tools/distill/method.py +++ b/python/tvm/contrib/msc/core/tools/distill/method.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class DistillMethod(object): """Default distill method""" @@ -68,5 +69,6 @@ def framework(cls): def tool_type(cls): return ToolType.DISTILLER - -msc_utils.register_tool_method(DistillMethod) + @classmethod + def method_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py index 7623de109e08..22cb52a60b6d 100644 --- a/python/tvm/contrib/msc/core/tools/execute.py +++ b/python/tvm/contrib/msc/core/tools/execute.py @@ -86,7 +86,7 @@ def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> """ tool_style = config.pop("tool_style") if "tool_style" in config else "default" - tool_cls = msc_utils.get_registered_tool_cls(framework, tool_type, tool_style) + tool_cls = msc_utils.get_registered_tool(framework, tool_type, tool_style) assert tool_cls, "Can not find tool class for {}:{} @ {}".format( tool_type, tool_style, framework ) diff --git a/python/tvm/contrib/msc/core/tools/prune/method.py b/python/tvm/contrib/msc/core/tools/prune/method.py index fd3abe8df42b..91322ae91fef 100644 --- a/python/tvm/contrib/msc/core/tools/prune/method.py +++ b/python/tvm/contrib/msc/core/tools/prune/method.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class PruneMethod(object): """Default prune method""" @@ -114,5 +115,6 @@ def framework(cls): def tool_type(cls): return ToolType.PRUNER - -msc_utils.register_tool_method(PruneMethod) + @classmethod + def method_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 515ea09e0145..9f20240cf218 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -541,10 +541,8 @@ def tool_type(cls): return ToolType.PRUNER +@msc_utils.register_tool class DefaultPruner(BasePruner): @classmethod def tool_style(cls): return "default" - - -msc_utils.register_tool_cls(DefaultPruner) diff --git a/python/tvm/contrib/msc/core/tools/quantize/method.py b/python/tvm/contrib/msc/core/tools/quantize/method.py index 970185826711..05d0711ea9fa 100644 --- a/python/tvm/contrib/msc/core/tools/quantize/method.py +++ b/python/tvm/contrib/msc/core/tools/quantize/method.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class QuantizeMethod(object): """Default quantize method""" @@ -468,5 +469,6 @@ def framework(cls): def tool_type(cls): return ToolType.QUANTIZER - -msc_utils.register_tool_method(QuantizeMethod) + @classmethod + def method_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py index 8bf8242bb4b2..3d706002d6c6 100644 --- a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py @@ -254,10 +254,8 @@ def tool_type(cls): return ToolType.QUANTIZER +@msc_utils.register_tool class DefaultQuantizer(BaseQuantizer): @classmethod def tool_style(cls): return "default" - - -msc_utils.register_tool_cls(DefaultQuantizer) diff --git a/python/tvm/contrib/msc/core/tools/track/configer.py b/python/tvm/contrib/msc/core/tools/track/configer.py index fafb30d4842c..ef9c18c3f72e 100644 --- a/python/tvm/contrib/msc/core/tools/track/configer.py +++ b/python/tvm/contrib/msc/core/tools/track/configer.py @@ -25,9 +25,18 @@ class TrackConfiger(ToolConfiger): """Configer for track""" - @property - def apply_once(self): - return False + def config_apply(self) -> dict: + """Get the config fro apply + + Returns + ------- + config: dict + The apply config. + """ + + config = super().config_apply() + config.update({"apply_once": True}) + return config @classmethod def tool_type(cls): diff --git a/python/tvm/contrib/msc/core/tools/track/method.py b/python/tvm/contrib/msc/core/tools/track/method.py index 7d02456f4359..44d3813600e2 100644 --- a/python/tvm/contrib/msc/core/tools/track/method.py +++ b/python/tvm/contrib/msc/core/tools/track/method.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class TrackMethod(object): """Default track method""" @@ -95,6 +96,3 @@ def tool_type(cls): @classmethod def method_style(cls): return "default" - - -msc_utils.register_tool_method(TrackMethod) diff --git a/python/tvm/contrib/msc/core/tools/track/tracker.py b/python/tvm/contrib/msc/core/tools/track/tracker.py index bb60b9fe8b2d..510153a5c4e5 100644 --- a/python/tvm/contrib/msc/core/tools/track/tracker.py +++ b/python/tvm/contrib/msc/core/tools/track/tracker.py @@ -185,10 +185,8 @@ def tool_type(cls): return ToolType.TRACKER +@msc_utils.register_tool class DefaultTracker(BaseTracker): @classmethod def tool_style(cls): return "default" - - -msc_utils.register_tool_cls(DefaultTracker) diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py index fa9f339a7524..b18e88888723 100644 --- a/python/tvm/contrib/msc/core/utils/expr.py +++ b/python/tvm/contrib/msc/core/utils/expr.py @@ -17,6 +17,7 @@ """tvm.contrib.msc.core.utils.expr""" import copy +from typing import Dict import tvm from tvm import relax @@ -44,6 +45,28 @@ def get_expr_name(expr: relax.Expr) -> str: return name +def make_span(kwargs: Dict[str, str], span: relax.Span = None) -> relax.Span: + """Change name to span + + Parameters + ---------- + kwargs: dict + The attrs in span. + span: relax.Span + The source span. + + Returns + ------- + span: relax.Span + The span. + """ + + span = span or relax.Span(tvm.ir.SourceName(""), 0, 0, 0, 0) + for k, v in kwargs.items(): + span = _ffi_api.SpanSetAttr(span, _ffi_api.ToAttrKey(k), v) + return span + + def set_expr_name(expr: relax.Expr, name: str): """Set the name for expr @@ -60,7 +83,7 @@ def set_expr_name(expr: relax.Expr, name: str): The expr with name. """ - expr.span = _ffi_api.SpanSetAttr(expr.span, _ffi_api.ToAttrKey("name"), name) + expr.span = make_span({"name": name}, expr.span) return expr diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index 49912b4d041b..b1eb8fa8bfa1 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -110,6 +110,27 @@ def __exit__(self, exception_type, exception_value, traceback): def __del__(self): self.clean_up() + def finalize(self): + """Finalize the directory""" + + if not os.path.isdir(self._path): + return self._path + + def _remove_empty(path: str): + sub_paths = [os.path.join(path, f) for f in os.listdir(path)] + for s_path in sub_paths: + if not os.path.isdir(s_path): + continue + if len(os.listdir(s_path)) == 0: + shutil.rmtree(s_path) + else: + _remove_empty(s_path) + if len(os.listdir(path)) == 0: + shutil.rmtree(path) + return path + + return _remove_empty(self._path) + def clean_up(self): """Clean up the dir""" @@ -384,7 +405,7 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T return root_dir.relpath(path, keep_history) -def pack_folder(path: str, style="tar"): +def pack_folder(path: str, style="tar.gz"): """Pack the folder Parameters @@ -401,7 +422,7 @@ def pack_folder(path: str, style="tar"): """ root = os.path.dirname(path) - if style == "tar": + if style == "tar.gz": cmd = "tar --exculde={0}.tar.gz -zcvf {0}.tar.gz {0} && rm -rf {0}".format(path) else: raise NotImplementedError("Pack style {} is not supported".format(style)) @@ -411,6 +432,7 @@ def pack_folder(path: str, style="tar"): else: retcode = subprocess.call(cmd, shell=True) assert retcode == 0, "Failed to pack the folder {}({}): {}".format(path, style, retcode) + return path + "." + style get_build_dir = partial(get_workspace_subdir, name="Build") diff --git a/python/tvm/contrib/msc/core/utils/log.py b/python/tvm/contrib/msc/core/utils/log.py index 916eb2468860..1422ad9a1bd0 100644 --- a/python/tvm/contrib/msc/core/utils/log.py +++ b/python/tvm/contrib/msc/core/utils/log.py @@ -135,3 +135,11 @@ def get_global_logger() -> logging.Logger: if not MSCMap.get(MSCKey.GLOBALE_LOGGER): MSCMap.set(MSCKey.GLOBALE_LOGGER, IOLogger()) return MSCMap.get(MSCKey.GLOBALE_LOGGER) + + +def remove_loggers(): + """Remove the logger handlers""" + + logger = MSCMap.get(MSCKey.GLOBALE_LOGGER) + if logger: + logger.handlers.clear() diff --git a/python/tvm/contrib/msc/core/utils/message.py b/python/tvm/contrib/msc/core/utils/message.py index 1479a99dd5db..d7b64ee22ea3 100644 --- a/python/tvm/contrib/msc/core/utils/message.py +++ b/python/tvm/contrib/msc/core/utils/message.py @@ -39,6 +39,7 @@ class MSCStage(object): OPTIMIZE = "optimize" COMPILE = "compile" SUMMARY = "summary" + EXPORT = "export" ALL = [ SETUP, PREPARE, @@ -51,6 +52,7 @@ class MSCStage(object): OPTIMIZE, COMPILE, SUMMARY, + EXPORT, ] @classmethod diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index ae7c8eac03b3..be82e1d0907a 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -25,13 +25,12 @@ class MSCRegistery: REGISTERY = {} MSC_FUNCS = "msc_funcs" - MSC_TOOLS_CLS = "msc_tools_cls" - MSC_TOOLS_METHOD = "msc_tools_method" + TOOL_CLASSES = "tool_classes" + TOOL_METHODS = "tool_methods" TOOL_CONFIGERS = "tool_configers" GYM_CONFIGERS = "gym_configers" GYM_CONTROLLERS = "gym_controllers" - GYM_AGENTS = "gym_agents" - GYM_ENVS = "gym_envs" + GYM_OBJECTS = "gym_objects" GYM_METHODS = "gym_agents_method" RUNNER_HOOKS = "runner_hooks" @@ -101,29 +100,25 @@ def get_registered_func(name: str, framework: str = MSCFramework.MSC): return funcs[framework].get(name) -def register_tool_cls(tool_cls: Any): +def register_tool(tool: Any): """Register a tool class. Parameters ---------- - tool_cls: class + tool: class The tool class to be registered. """ - tools_cls = MSCRegistery.get(MSCRegistery.MSC_TOOLS_CLS, {}) for key in ["framework", "tool_type", "tool_style"]: - assert hasattr(tool_cls, key), "{} should be given to register tool class".format(key) - if tool_cls.framework() not in tools_cls: - tools_cls[tool_cls.framework()] = {} - framework_tools = tools_cls[tool_cls.framework()] - if tool_cls.tool_type() not in framework_tools: - framework_tools[tool_cls.tool_type()] = {} - tools = framework_tools[tool_cls.tool_type()] - tools[tool_cls.tool_style()] = tool_cls - MSCRegistery.register(MSCRegistery.MSC_TOOLS_CLS, tools_cls) - - -def get_registered_tool_cls(framework: str, tool_type: str, tool_style: str) -> Any: + assert hasattr(tool, key), "{} should be given to register tool".format(key) + tools_classes = MSCRegistery.get(MSCRegistery.TOOL_CLASSES, {}) + col = tools_classes.setdefault(tool.framework(), {}).setdefault(tool.tool_type(), {}) + col[tool.tool_style()] = tool + MSCRegistery.register(MSCRegistery.TOOL_CLASSES, tools_classes) + return tool + + +def get_registered_tool(framework: str, tool_type: str, tool_style: str) -> Any: """Get the registered tool class. Parameters @@ -137,35 +132,32 @@ def get_registered_tool_cls(framework: str, tool_type: str, tool_style: str) -> Returns ------- - tool_cls: class + tool: class The registered tool class. """ - tools_cls = MSCRegistery.get(MSCRegistery.MSC_TOOLS_CLS, {}) + tools_classes = MSCRegistery.get(MSCRegistery.TOOL_CLASSES, {}) if tool_style == "all": - return tools_cls.get(framework, {}).get(tool_type, {}) - return tools_cls.get(framework, {}).get(tool_type, {}).get(tool_style) + return tools_classes.get(framework, {}).get(tool_type, {}) + return tools_classes.get(framework, {}).get(tool_type, {}).get(tool_style) -def register_tool_method(method_cls: Any, method_style: str = "default"): +def register_tool_method(method: Any): """Register a tool method. Parameters ---------- - method_cls: class + method: class The method class. - method_style: string - The style of the method. """ - tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) - for key in ["framework", "tool_type"]: - assert hasattr(method_cls, key), "{} should be given to register tool method".format(key) - if method_cls.framework() not in tools_method: - tools_method[method_cls.framework()] = {} - register_name = "{}.{}".format(method_cls.tool_type(), method_style) - tools_method[method_cls.framework()][register_name] = method_cls - MSCRegistery.register(MSCRegistery.MSC_TOOLS_METHOD, tools_method) + for key in ["framework", "tool_type", "method_style"]: + assert hasattr(method, key), "{} should be given to register tool method".format(key) + tool_methods = MSCRegistery.get(MSCRegistery.TOOL_METHODS, {}) + col = tool_methods.setdefault(method.framework(), {}).setdefault(method.tool_type(), {}) + col[method.method_style()] = method + MSCRegistery.register(MSCRegistery.TOOL_METHODS, tool_methods) + return method def get_registered_tool_method( @@ -188,9 +180,8 @@ def get_registered_tool_method( The method class. """ - tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) - register_name = "{}.{}".format(tool_type, method_style) - return tools_method.get(framework, {}).get(register_name) + tool_methods = MSCRegistery.get(MSCRegistery.TOOL_METHODS, {}) + return tool_methods.get(framework, {}).get(tool_type, {}).get(method_style) def register_tool_configer(configer: Any): @@ -240,10 +231,11 @@ def register_gym_configer(configer: Any): The configer class. """ - configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) assert hasattr(configer, "config_type"), "config_type should be given to register configer" - configers[configer.config_type()] = configer - MSCRegistery.register(MSCRegistery.GYM_CONFIGERS, configers) + gym_configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) + gym_configers[configer.config_type()] = configer + MSCRegistery.register(MSCRegistery.GYM_CONFIGERS, gym_configers) + return configer def get_registered_gym_configer(config_type: str) -> Any: @@ -260,8 +252,8 @@ def get_registered_gym_configer(config_type: str) -> Any: The configer class. """ - configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) - return configers.get(config_type) + gym_configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) + return gym_configers.get(config_type) def register_gym_controller(controller: Any): @@ -273,12 +265,13 @@ def register_gym_controller(controller: Any): The controller class. """ - controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) assert hasattr( controller, "control_type" ), "control_type should be given to register controller" - controllers[controller.control_type()] = controller - MSCRegistery.register(MSCRegistery.GYM_CONTROLLERS, controllers) + gym_controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) + gym_controllers[controller.control_type()] = controller + MSCRegistery.register(MSCRegistery.GYM_CONTROLLERS, gym_controllers) + return controller def get_registered_gym_controller(control_type: str) -> Any: @@ -295,74 +288,46 @@ def get_registered_gym_controller(control_type: str) -> Any: The controller class. """ - controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) - return controllers.get(control_type) - - -def register_gym_agent(agent: Any): - """Register a gym agent. - - Parameters - ---------- - agent: class - The agent class. - """ - - agents = MSCRegistery.get(MSCRegistery.GYM_AGENTS, {}) - assert hasattr(agent, "agent_type"), "agent_type should be given to register agent" - agents[agent.agent_type()] = agent - MSCRegistery.register(MSCRegistery.GYM_AGENTS, agents) + gym_controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) + return gym_controllers.get(control_type) -def get_registered_gym_agent(agent_type: str) -> Any: - """Get the registered agent. +def register_gym_object(obj: Any): + """Register a gym object. Parameters ---------- - agent_type: string - The type of agent. - - Returns - ------- - agent: class - The agent class. + obj: class + The object class. """ - agents = MSCRegistery.get(MSCRegistery.GYM_AGENTS, {}) - return agents.get(agent_type) + for key in ["role", "role_type"]: + assert hasattr(obj, key), "{} should be given to register gym object".format(key) + gym_objects = MSCRegistery.get(MSCRegistery.GYM_OBJECTS, {}) + col = gym_objects.setdefault(obj.role(), {}) + col[obj.role_type()] = obj + MSCRegistery.register(MSCRegistery.GYM_OBJECTS, gym_objects) + return obj -def register_gym_env(env: Any): - """Register a gym env. +def get_registered_gym_object(role: str, role_type: str) -> Any: + """Get the registered object. Parameters ---------- - env: class - The env class. - """ - - envs = MSCRegistery.get(MSCRegistery.GYM_ENVS, {}) - assert hasattr(env, "env_type"), "env_type should be given to register env" - envs[env.env_type()] = env - MSCRegistery.register(MSCRegistery.GYM_ENVS, envs) - - -def get_registered_gym_env(env_type: str) -> Any: - """Get the registered env. - - Parameters - ---------- - env_type: string - The type of agent. + role: string + The role. + role_type: string + The type of the role. Returns ------- - env: class - The agent class. + object: class + The object class. """ - envs = MSCRegistery.get(MSCRegistery.GYM_ENVS, {}) - return envs.get(env_type) + gym_objects = MSCRegistery.get(MSCRegistery.GYM_OBJECTS, {}) + return gym_objects.get(role, {}).get(role_type) def register_gym_method(method: Any): @@ -374,17 +339,22 @@ def register_gym_method(method: Any): The method class. """ - methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) - assert hasattr(method, "method_type"), "method_type should be given to register method" - methods[method.method_type()] = method - MSCRegistery.register(MSCRegistery.GYM_METHODS, methods) + for key in ["role", "method_type"]: + assert hasattr(method, key), "{} should be given to register gym method".format(key) + gym_methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) + col = gym_methods.setdefault(method.role(), {}) + col[method.method_type()] = method + MSCRegistery.register(MSCRegistery.GYM_METHODS, gym_methods) + return method -def get_registered_gym_method(method_type: str) -> Any: +def get_registered_gym_method(role: str, method_type: str) -> Any: """Get the registered gym method. Parameters ---------- + role: str + The role. method_type: str The type of method. @@ -394,8 +364,8 @@ def get_registered_gym_method(method_type: str) -> Any: The method class. """ - methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) - return methods.get(method_type) + gym_methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) + return gym_methods.get(role, {}).get(method_type) def register_runner_hook(hook: Any): @@ -407,10 +377,11 @@ def register_runner_hook(hook: Any): The hook class. """ - hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) assert hasattr(hook, "name"), "name should be given to register hook" + hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) hooks[hook.name()] = hook MSCRegistery.register(MSCRegistery.RUNNER_HOOKS, hooks) + return hook def get_registered_runner_hook(name: str) -> Any: diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py index 0385c6d94144..72f08ab19a41 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseDistiller) -> BaseDistiller: The distiller class. """ + @msc_utils.register_tool class Distiller(base_cls): """Adaptive distiller for tensorflow""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorflowDistillerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py index 9b3d9d4326db..5a34f21ec430 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py @@ -39,6 +39,7 @@ def create(self, base_cls: BasePruner) -> BasePruner: The pruner class. """ + @msc_utils.register_tool class Pruner(base_cls): """Adaptive pruner for tensorflow""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorflowPrunerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py index dd6f2aac38d2..8ce05d270861 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: The quantizer class. """ + @msc_utils.register_tool class Quantizer(base_cls): """Adaptive quantizer for tensorflow""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorflowQuantizerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py index 7023322681c9..6ab3a7764af3 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseTracker) -> BaseTracker: The tracker class. """ + @msc_utils.register_tool class Tracker(base_cls): """Adaptive tracker for tensorflow""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorflowTrackerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index d74a6a42461c..e38c5d7482a4 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -17,6 +17,7 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.tensorrt.runtime.runner""" +import os from typing import Any, List, Dict import tvm @@ -102,6 +103,28 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra return super()._generate_model(graphs, weights) + def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the runnable + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The runnable info. + """ + + info = super().export_runnable(folder) + info["engines"] = {} + for graph in self._graphs: + engine_file = msc_utils.get_output_dir().relpath(graph.name + ".trt") + assert os.path.isfile(engine_file), "Missing engine file " + engine_file + info["engines"] = folder.copy(engine_file) + return info + @classmethod def target_transform(cls, mod: tvm.IRModule): """Transform the mod by target. diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py index bc9ead6dcc83..6ec99dbfe931 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseDistiller) -> BaseDistiller: The distiller class. """ + @msc_utils.register_tool class Distiller(base_cls): """Adaptive distiller for tensorrt""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorRTDistillerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py index da591d9cebb6..418065480469 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py @@ -39,6 +39,7 @@ def create(self, base_cls: BasePruner) -> BasePruner: The pruner class. """ + @msc_utils.register_tool class Pruner(base_cls): """Adaptive pruner for tensorrt""" @@ -50,6 +51,6 @@ def framework(cls): factory = TensorRTPrunerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py index 0feb836d1350..982a37d74128 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py @@ -24,6 +24,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class TensorRTQuantizeMethod(QuantizeMethod): """Default quantize method for tensorrt""" @@ -144,6 +145,3 @@ def dequantize_normal( @classmethod def framework(cls): return MSCFramework.TENSORRT - - -msc_utils.register_tool_method(TensorRTQuantizeMethod) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py index e2402e2dfa62..ca2d78c4273c 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py @@ -45,6 +45,7 @@ def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: The quantizer class. """ + @msc_utils.register_tool class Quantizer(base_cls): """Adaptive quantizer for tensorrt""" @@ -357,6 +358,6 @@ def framework(cls): factory = TensorRTQuantizerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tensorrt/tools/track/tracker.py index 10ae794ca056..fa59131ff48f 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/track/tracker.py @@ -42,6 +42,7 @@ def create(self, base_cls: BaseTracker) -> BaseTracker: The tracker class. """ + @msc_utils.register_tool class Tracker(base_cls): """Adaptive tracker for tensorrt""" @@ -154,6 +155,6 @@ def framework(cls): factory = TensorRTTrackerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py index 688cfd8b30b9..51cc2180581f 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py @@ -43,6 +43,7 @@ def create(self, base_cls: BaseDistiller) -> BaseDistiller: The distiller class. """ + @msc_utils.register_tool class Distiller(base_cls): """Adaptive distiller for torch""" @@ -139,6 +140,6 @@ def framework(cls): factory = TorchDistillerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py index 7de3fdbbacaa..9d6956ae6f06 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py +++ b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils +@msc_utils.register_tool_method class TorchDistillMethod(DistillMethod): """Default quantize method for torch""" @@ -111,6 +112,3 @@ def loss_lp_norm( @classmethod def framework(cls): return MSCFramework.TORCH - - -msc_utils.register_tool_method(TorchDistillMethod) diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py index 4dfcf21dca55..9272a24b2eac 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py @@ -39,6 +39,7 @@ def create(self, base_cls: BasePruner) -> BasePruner: The pruner class. """ + @msc_utils.register_tool class Pruner(base_cls): """Adaptive pruner for torch""" @@ -50,6 +51,6 @@ def framework(cls): factory = TorchPrunerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py index 9b36d89b7b93..8efc0efa598e 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py @@ -55,6 +55,7 @@ def backward(ctx, grad_outputs): return wrapper +@msc_utils.register_tool_method class TorchQuantizeMethod(QuantizeMethod): """Default quantize method for torch""" @@ -264,6 +265,3 @@ def quantize_normal( @classmethod def framework(cls): return MSCFramework.TORCH - - -msc_utils.register_tool_method(TorchQuantizeMethod) diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py index 0e5c599b877a..a1359631ad06 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: The quantizer class. """ + @msc_utils.register_tool class Quantizer(base_cls): """Adaptive quantizer for torch""" @@ -50,6 +51,6 @@ def framework(cls): factory = TorchQuantizerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py b/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py index 0fa065153bf5..8924b53cc583 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseTracker) -> BaseTracker: The tracker class. """ + @msc_utils.register_tool class Tracker(base_cls): """Adaptive tracker for torch""" @@ -50,6 +51,6 @@ def framework(cls): factory = TorchTrackerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index ab52b8de99d2..b4f052f08dfe 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -57,6 +57,18 @@ def __call__(self, *inputs) -> List[tvm.nd.array]: class TVMRunner(ModelRunner): """Runner of Relax""" + def setup(self) -> dict: + """Setup the runner + + Returns + ------- + info: dict + The setup info. + """ + + self._executable = None + return super().setup() + def _build_runnable(self, model: Any) -> Any: """Build runnable object @@ -88,15 +100,15 @@ def _build_runnable(self, model: Any) -> Any: if self._device.startswith("cpu"): target = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.relax.build(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) + self._executable = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(self._executable, tvm.cpu()) elif self._device.startswith("cuda"): target = tvm.target.Target("cuda") with target: model = tvm.tir.transform.DefaultGPUSchedule()(model) with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.relax.build(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) + self._executable = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(self._executable, tvm.cuda()) else: raise NotImplementedError("Unsupported device " + str(self._device)) return WrapRunnable(runnable) @@ -143,6 +155,24 @@ def _device_enabled(self, device: str) -> bool: return tvm.cuda(dev_id).exist return False + def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: + """Export the runnable + + Parameters + ------- + folder: MSCDirectory + The export folder. + + Returns + ------- + info: dict + The runnable info. + """ + + export_path = folder.relpath("model.so") + self._executable.export_library(export_path) + return {"model": export_path} + @property def codegen_func(self): return to_relax diff --git a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py index 9cfc99dc1aef..8c42542d1b31 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py @@ -39,6 +39,7 @@ def create(self, base_cls: BaseDistiller) -> BaseDistiller: The distiller class. """ + @msc_utils.register_tool class Distiller(base_cls): """Adaptive distiller for tvm""" @@ -50,6 +51,6 @@ def framework(cls): factory = TVMDistillerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py index 198a6985466a..51d50fc7b861 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py @@ -39,6 +39,7 @@ def create(self, base_cls: BasePruner) -> BasePruner: The pruner class. """ + @msc_utils.register_tool class Pruner(base_cls): """Adaptive pruner for tvm""" @@ -50,6 +51,6 @@ def framework(cls): factory = TVMPrunerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py index 5a534991b93f..d56193d9f7c1 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py @@ -28,6 +28,7 @@ from tvm.contrib.msc.core import _ffi_api +@msc_utils.register_tool_method class TVMQuantizeMethod(QuantizeMethod): """Default quantize method for tvm""" @@ -200,6 +201,3 @@ def dequantize_normal( @classmethod def framework(cls): return MSCFramework.TVM - - -msc_utils.register_tool_method(TVMQuantizeMethod) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py index d4680b9088b3..173dc7c3d9e8 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py @@ -43,6 +43,7 @@ def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: The quantizer class. """ + @msc_utils.register_tool class Quantizer(base_cls): """Adaptive quantizer for tvm""" @@ -162,6 +163,6 @@ def framework(cls): factory = TVMQuantizerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py index 0054b7e77349..2bb0de02be22 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py @@ -43,6 +43,7 @@ def create(self, base_cls: BaseTracker) -> BaseTracker: The tracker class. """ + @msc_utils.register_tool class Tracker(base_cls): """Adaptive tracker for tvm""" @@ -153,6 +154,6 @@ def framework(cls): factory = TVMTrackerFactory() -tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") +tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") for tool in tools.values(): - msc_utils.register_tool_cls(factory.create(tool)) + factory.create(tool) diff --git a/python/tvm/contrib/msc/pipeline/config.py b/python/tvm/contrib/msc/pipeline/config.py index 16ff34f2eca6..b6d80fd42089 100644 --- a/python/tvm/contrib/msc/pipeline/config.py +++ b/python/tvm/contrib/msc/pipeline/config.py @@ -116,8 +116,8 @@ def create_config( baseline_type = baseline_type or model_type optimize_type = optimize_type or baseline_type compile_type = compile_type or optimize_type - if tools: - tools = [config_tool(t_type, t_config) for t_type, t_config in tools] + tools = tools or [] + tools = [config_tool(t_type, t_config) for t_type, t_config in tools] # basic config config = { "model_type": model_type, @@ -133,7 +133,8 @@ def create_config( } # config optimize - if tools: + opt_tools = [t for t in tools if support_tool(t, MSCStage.OPTIMIZE, optimize_type)] + if opt_tools: config[MSCStage.OPTIMIZE] = { "run_type": optimize_type, "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, @@ -145,6 +146,10 @@ def create_config( "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, } + # update config + if extra_config: + config = msc_utils.update_dict(config, extra_config) + # skip stages skip_config = skip_config or {} for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: @@ -164,7 +169,4 @@ def create_config( else: raise TypeError("Unexpected skip type " + str(skip_config[key])) - # update config - if extra_config: - config = msc_utils.update_dict(config, extra_config) return config diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index c0b93569c843..e0f734af6cb5 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -20,6 +20,7 @@ import os import time import json +import logging from typing import Dict, Any, Union, List import traceback import numpy as np @@ -68,7 +69,7 @@ def __init__( if root: def _from_root_mark(val): - if root and isinstance(val, str) and MSCKey.ROOT_MARK in val: + if isinstance(val, str) and MSCKey.ROOT_MARK in val: return val.replace(MSCKey.ROOT_MARK, root) return val @@ -77,7 +78,15 @@ def _from_root_mark(val): plugins = msc_utils.map_dict(plugins, _from_root_mark) # check stage - for stage in ["inputs", "outputs", "dataset", MSCStage.PREPARE, MSCStage.COMPILE]: + for stage in [ + "inputs", + "outputs", + "dataset", + MSCStage.PREPARE, + MSCStage.PARSE, + MSCStage.COMPILE, + MSCStage.EXPORT, + ]: config.setdefault(stage, {}) MSCMap.reset() @@ -162,13 +171,9 @@ def update_config(self, config: dict) -> dict: The updated config. """ - # update prepare and parse assert "inputs" in config, "inputs should be given to run manager" assert "outputs" in config, "outputs should be given to run manager" config, debug_levels = msc_utils.copy_dict(config), {} - for stage in [MSCStage.PREPARE, MSCStage.PARSE]: - if stage not in config: - config[stage] = {} config = self._get_runner_cls(self._model_type).update_config( MSCStage.PARSE, config, self._model ) @@ -186,6 +191,9 @@ def update_config(self, config: dict) -> dict: if config.get("tools"): config["tools"] = self._update_tools_config(config["tools"]) + # update export config + config[MSCStage.EXPORT].update({"inputs": config["inputs"], "outputs": config["outputs"]}) + def _set_debug_level(stage: str, sub_config: dict, default: int = None) -> dict: if "debug_level" in sub_config: debug_levels[stage] = sub_config["debug_level"] @@ -218,6 +226,7 @@ def _set_debug_level(stage: str, sub_config: dict, default: int = None) -> dict: MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE, + MSCStage.EXPORT, ] return {k: config[k] for k in ordered_keys if k in config}, debug_levels @@ -230,7 +239,7 @@ def run_pipe(self) -> dict: The pipeline report. """ - err_msg = None + err_msg, err_info = None, None try: self.prepare() self.parse() @@ -241,9 +250,11 @@ def run_pipe(self) -> dict: if MSCStage.COMPILE in self._config: self.compile() except Exception as exc: # pylint: disable=broad-exception-caught - err_msg = "Pipeline failed:{}\nTrace: {}".format(exc, traceback.format_exc()) - self.summary(err_msg) + err_msg = "Pipeline failed: " + str(exc) + err_info = traceback.format_exc() + self.summary(err_msg, err_info) self._logger.info(msc_utils.msg_block("SUMMARY", self._report, 0)) + self._workspace.finalize() return self._report def prepare(self) -> Dict[str, np.ndarray]: @@ -334,9 +345,12 @@ def parse(self) -> tvm.IRModule: msc_utils.time_stamp(MSCStage.PARSE) stage_config = self._config[MSCStage.PARSE] - use_cache = self._config.get("use_cache", True) - - cache_path = msc_utils.get_cache_dir().relpath("parsed_relax.json") if use_cache else None + if self._config.get("use_cache", True): + cache_path = ( + msc_utils.get_cache_dir().create_dir(MSCStage.PARSE).relpath("parsed_relax.json") + ) + else: + cache_path = None if cache_path and os.path.isfile(cache_path): with open(cache_path, "r") as f: self._relax_mod = tvm.ir.load_json(f.read()) @@ -447,13 +461,15 @@ def apply_tools(self, stage: str): self._logger.debug("Remove apply once tool %s", tool["tool_type"]) self._tools_config = self._tools_config[:-1] - def summary(self, err_msg=None): + def summary(self, err_msg=None, err_info: str = None): """Summary the pipeline. Parameters ---------- err_msg: str The error message. + err_info: str + The error info. Returns ------- @@ -463,7 +479,7 @@ def summary(self, err_msg=None): msc_utils.time_stamp(MSCStage.SUMMARY, False) if err_msg: - self._report.update({"success": False, "err_msg": err_msg}) + self._report.update({"success": False, "err_msg": err_msg, "err_info": err_info}) else: self._report["success"] = True self._report["duration"] = msc_utils.get_duration() @@ -490,29 +506,72 @@ def export(self, path: str = None, dump: bool = True) -> Union[str, dict]: folder, dump = msc_utils.msc_dir(path.replace(".tar.gz", ""), keep_history=False), True else: folder = msc_utils.msc_dir(path, keep_history=False) - if dump: - plugins = export_plugins(self._plugins, folder.create_dir("plugin")) - else: - plugins = self._plugins def _to_root_mark(val): if isinstance(val, str) and folder.path != val and folder.path in val: return val.replace(folder.path, MSCKey.ROOT_MARK) return val - pipeline = { - "model": self.export_model(folder.create_dir("model"), dump), - "config": self.export_config(folder, dump), - "plugins": plugins, - "root": folder.path, - } - pipeline = msc_utils.map_dict(pipeline, _to_root_mark) - if not dump: - return pipeline - with open(folder.relpath("pipeline.json"), "w") as f: - f.write(json.dumps(pipeline, indent=2)) + # export compiled + if self._compiled: + if not dump: + return self._runner.runnable + model = self._runner.export_runnable(folder) + if self._plugins: + plugin = self._plugins[self.compile_type] + model["plugins"] = plugin.copy_libs(folder.create_dir("plugins")) + model.update( + { + "device": self._runner.device, + "model_type": self.compile_type, + "abstract": self._runner.model_info, + } + ) + # save golden + num_golden = self._config[MSCStage.EXPORT].get("num_golden", 0) + if num_golden > 0: + saver_options = { + "input_names": [i[0] for i in self._config["inputs"]], + "output_names": self._config["outputs"], + } + batch_cnt, model["golden"] = 0, folder.create_dir("golden").path + with msc_utils.IODataSaver(model["golden"], saver_options) as saver: + for inputs in self._get_loader()(): + if batch_cnt >= num_golden: + break + batch_cnt = saver.save_batch(inputs, self._runner.run(inputs)) + model = msc_utils.map_dict(model, _to_root_mark) + with open(folder.relpath("model.json"), "w") as f: + f.write(json.dumps(model, indent=2)) + else: + if dump: + plugins = export_plugins(self._plugins, folder.create_dir("plugins")) + else: + plugins = self._plugins + + pipeline = { + "model": self.export_model(folder.create_dir("model"), dump), + "config": self.export_config(folder, dump), + "plugins": plugins, + "root": folder.path, + } + pipeline = msc_utils.map_dict(pipeline, _to_root_mark) + if not dump: + return pipeline + with open(folder.relpath("pipeline.json"), "w") as f: + f.write(json.dumps(pipeline, indent=2)) + # copy common files + if self._optimized or self._compiled: + stage = MSCStage.COMPILE if self._compiled else MSCStage.OPTIMIZE + msc_utils.get_visual_dir().copy(stage, folder.relpath("visualize")) + for log_h in self._logger.handlers: + if isinstance(log_h, logging.FileHandler): + folder.copy(log_h.baseFilename) + with open(folder.relpath("report.json"), "w") as f: + f.write(json.dumps(self._report, indent=2)) + folder.finalize() if path.endswith(".tar.gz"): - msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar") + msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar.gz") return path def export_model(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: @@ -531,8 +590,6 @@ def export_model(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any The exported model. """ - if self._compiled: - return self._runner._save_runnable(folder) if dump else self._runner.runnable if self._optimized: module = self._runner.export_module(folder) if not dump: @@ -543,7 +600,9 @@ def export_model(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any return {"model": path} if not dump: return self._model - return self._get_runner_cls(self._model_type).dump_nativate(self._model, folder) + return self._get_runner_cls(self._model_type).dump_nativate( + self._model, folder, **self._config[MSCStage.EXPORT] + ) def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> dict: """Export the config @@ -561,9 +620,6 @@ def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> di The updated config. """ - if self._compiled: - return {"model_info": self.runner.model_info} - # dump the dataloader def _save_dataset(name, info, dump: bool): loader, max_batch = info["loader"], info.get("max_batch", -1) @@ -631,6 +687,7 @@ def destory(self, keep_workspace: bool = False): self._runner.destory() if not keep_workspace: self._workspace.destory() + msc_utils.remove_loggers() def _create_runner( self, @@ -689,7 +746,7 @@ def _create_runner( runner.build(cache_dir=cache_dir) self._report["info"][stage + "_type"] = "{}({})".format(runner.framework, runner.device) if visualize: - runner.visualize(msc_utils.get_visual_dir().create_dir(stage)) + runner.visualize(msc_utils.get_visual_dir().create_dir(stage.split(".")[0])) if profile and "profile" in stage_config: self._report["profile"][stage] = self._profile_runner(runner, stage_config) if use_cache: @@ -725,7 +782,9 @@ def _apply_tool(self, tool: dict, stage: str) -> str: "run_type": tool.get("run_type", self._config[stage]["run_type"]), "run_config": self._config[stage]["run_config"], } - runner = self._create_runner(t_stage, stage_config, profile=False, use_cache=False) + runner = self._create_runner( + t_stage, stage_config, visualize=False, profile=False, use_cache=False + ) if "gym_configs" in tool: knowledge = None for idx, config in enumerate(tool["gym_configs"]): @@ -756,7 +815,10 @@ def _apply_tool(self, tool: dict, stage: str) -> str: self._logger.info("%sFound %d plan", gym_mark, len(plan)) return msc_utils.save_dict(plan, plan_file) msc_utils.time_stamp(t_stage + ".make_plan", False) - return runner.make_plan(tool_type, self._get_loader(tool_stage)) + plan_file = runner.make_plan(tool_type, self._get_loader(tool_stage)) + if tool.get("visualize", False): + runner.visualize(msc_utils.get_visual_dir().create_dir(stage)) + return plan_file def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: """Profile the runner. diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py index c790b5ef27be..2b69034cab70 100644 --- a/python/tvm/contrib/msc/pipeline/wrapper.py +++ b/python/tvm/contrib/msc/pipeline/wrapper.py @@ -20,6 +20,7 @@ from typing import Any, Union, List from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils from .manager import MSCManager @@ -37,8 +38,6 @@ class BaseWrapper(object): The config for pipeline plugins: dict The plugins for pipeline. - debug: bool - Whether to use debug mode. """ def __init__( @@ -47,14 +46,13 @@ def __init__( config: dict, workspace: str = "msc_workspace", plugins: dict = None, - debug: bool = False, ): self._meta_model = model self._optimized_model, self._compiled_model = None, None self._config = config self._plugins = plugins verbose = config.get("verbose", "info") - self._debug = True if verbose.startswith("debug") else debug + self._debug = verbose.startswith("debug") self._workspace = msc_utils.msc_dir(workspace, keep_history=self._debug) log_path = self._workspace.relpath("MSC_LOG", keep_history=False) self._config["logger"] = msc_utils.create_file_logger(verbose, log_path) @@ -92,9 +90,15 @@ def optimize(self, workspace: str = "Optimize"): self.logger.info("[Wrapper] Start optimize model") config = msc_utils.copy_dict(self._config) config["workspace"] = self._workspace.create_dir(workspace) + if MSCStage.OPTIMIZE not in config: + config[MSCStage.OPTIMIZE] = { + "run_type": self.model_type(), + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + } self._manager = MSCManager(self._meta_model, config, self._plugins, run_compile=False) - self._manager.run_pipe() - self._optimized_model = self._manager.get_runnable("runnable") + report = self._manager.run_pipe() + if report["success"]: + self._optimized_model = self._manager.get_runnable("runnable") return self def compile( @@ -118,8 +122,9 @@ def compile( pipeline = self.export(ckpt_path, dump=dump) pipeline["config"]["workspace"] = self._workspace.create_dir(workspace) self._manager = MSCManager(**pipeline) - self._manager.run_pipe() - self._compiled_model = self._manager.get_runnable("runnable") + report = self._manager.run_pipe() + if report["success"]: + self._compiled_model = self._manager.get_runnable("runnable") if not self._debug: shutil.rmtree(ckpt_path) else: @@ -127,8 +132,9 @@ def compile( config = msc_utils.copy_dict(self._config) config["workspace"] = self._workspace.create_dir(workspace) self._manager = MSCManager(self._meta_model, config, self._plugins) - self._manager.run_pipe() - self._compiled_model = self._manager.get_runnable("runnable") + report = self._manager.run_pipe() + if report["success"]: + self._compiled_model = self._manager.get_runnable("runnable") return self def export(self, path: str = "msc_export", dump: bool = True) -> Union[str, dict]: