diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index 74ee6cb46816..1243f910a4af 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit 74ee6cb46816267515c08eb78755d2b9b8db0bb4 +Subproject commit 1243f910a4afd49b7983c087e4f610b81e45f71c diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 8d987b98f7f7..7d3a47310af1 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 8d987b98f7f7b9381097566643a7f53c99cf312d +Subproject commit 7d3a47310af1ac0795e0d8e8435e42c882c96a13 diff --git a/CMakeLists.txt b/CMakeLists.txt index b09ad9e54250..fcd670429faf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -886,6 +886,8 @@ if(USE_CUDA AND USE_CUTLASS) install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) target_link_libraries(tvm PRIVATE fpA_intB_gemm) target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm) + target_link_libraries(tvm PRIVATE fpA_intB_gemm_tvm) + target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm_tvm) install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn) diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index bd3e3b116659..9ce27820b8f2 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -19,6 +19,9 @@ if(USE_CUDA AND USE_CUTLASS) tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) + set(FPA_INTB_GEMM_TVM_BINDING ON) + set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR}) + set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass) add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm) add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn) diff --git 
a/python/tvm/contrib/msc/core/gym/__init__.py b/python/tvm/contrib/msc/core/gym/__init__.py new file mode 100644 index 000000000000..e562cf495cee --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.gym""" + +from .environment import * +from .agent import * +from .control import * diff --git a/python/tvm/contrib/msc/core/gym/agent/__init__.py b/python/tvm/contrib/msc/core/gym/agent/__init__.py new file mode 100644 index 000000000000..e71ba5d7fbad --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/agent/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
"""tvm.contrib.msc.core.gym.base_agent"""

import copy
import logging
from typing import Dict, Any, List, Tuple
from tvm.contrib.msc.core import utils as msc_utils


class BaseAgent(object):
    """Basic Agent of MSC.Gym

    Parameters
    ----------
    name: str
        The name of agent.
    workspace: MSCDirectory
        The workspace.
    executors: dict
        The executors of the agent.
    options: dict
        The extra options for the agent.
    debug_level: int
        The debug level.
    logger: logging.Logger
        The logger; a file logger is created in the workspace when omitted.
    """

    def __init__(
        self,
        name: str,
        workspace: msc_utils.MSCDirectory,
        executors: dict,
        options: dict = None,
        debug_level: int = 0,
        logger: logging.Logger = None,
    ):
        self._name = name
        self._workspace = workspace
        # copy before parsing so the caller's dict is never mutated
        self._executors = self._parse_executors(msc_utils.copy_dict(executors))
        self._options = options or {}
        self._debug_level = debug_level
        if logger:
            self._logger = logger
        else:
            verbose = "debug" if debug_level > 0 else "info"
            self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("AGENT_LOG"))
        self._logger.info(
            msc_utils.msg_block("AGENT.SETUP({})".format(self.agent_type()), self.setup())
        )

    def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[str, callable, dict]]:
        """Parse the executors

        Parameters
        ----------
        executors_dict: dict
            The given executors, each value holding "method" and optional "method_type".

        Returns
        -------
        executors: dict
            The parsed executors, name -> (method_name, method, leftover_config).
        """

        executors = {}
        for name, raw_config in executors_dict.items():
            method_type = raw_config.pop("method_type", "agent.default")
            method_cls = msc_utils.get_registered_gym_method(method_type)
            assert "method" in raw_config, "method should be given to find agent method"
            method_name, method = raw_config.pop("method"), None
            # prefer a method on the registered class, fall back to registered funcs
            if hasattr(method_cls, method_name):
                method = getattr(method_cls, method_name)
            if not method:
                method = msc_utils.get_registered_func(method_name)
            assert method, "Can not find method " + str(method_name)
            # remaining keys become default kwargs for the executor call
            executors[name] = (method_name, method, copy.deepcopy(raw_config))
        return executors

    def setup(self) -> dict:
        """Setup the agent

        Returns
        -------
        info: dict
            The setup info.
        """

        self._knowledge = {"observations": [], "actions": [], "rewards": []}
        return {
            "name": self._name,
            "workspace": self._workspace,
            "executors": {k: "{}({})".format(v[0], v[2]) for k, v in self._executors.items()},
            "options": self._options,
            "debug_level": self._debug_level,
        }

    def init(self, max_task: int, baseline: Dict[str, Any]):
        """Init the agent

        Parameters
        ----------
        max_task: int
            The max task for agent.
        baseline: dict
            The baseline of environment.
        """

        self._max_task = max_task
        self._baseline = baseline

    def reset(self):
        """Reset the agent, discarding gathered knowledge"""

        self._knowledge = {"observations": [], "actions": [], "rewards": []}

    def choose_action(self, task_id: int, observation: Any, action_space: List[dict]) -> List[dict]:
        """Choose action based on observation

        Parameters
        ----------
        task_id: int
            The current task id.
        observation:
            The current observation.
        action_space: list
            The possible action space

        Returns
        -------
        actions: list
            The actions for next task.

        Raises
        ------
        TypeError
            If task_id is neither the next task nor the current (repeated) task.
        """

        actions = self._choose_action(task_id, observation, action_space)
        if task_id == len(self._knowledge["observations"]):
            # a fresh task: record its observation and actions
            self._knowledge["observations"].append(observation)
            self._knowledge["actions"].append(actions)
        elif task_id == len(self._knowledge["observations"]) - 1:
            # the same task revisited: extend the recorded actions
            self._knowledge["actions"][-1].extend(actions)
        else:
            raise TypeError(
                "Step id should be either {0} or {0}-1, get {1}".format(
                    len(self._knowledge["observations"]), task_id
                )
            )
        return actions

    def _choose_action(
        self, task_id: int, observation: Any, action_space: List[dict]
    ) -> List[dict]:
        """Choose action based on observation (implemented by subclasses)

        Parameters
        ----------
        task_id: int
            The current task id.
        observation:
            The current observation.
        action_space: list
            The possible action space

        Returns
        -------
        actions: list
            The actions for next task.
        """

        raise NotImplementedError("_choose_action is not implemented in BaseAgent")

    def store(self, task_id: int, rewards: List[dict]) -> int:
        """Store rewards

        Parameters
        ----------
        task_id: int
            The current task id.
        rewards: list
            The rewards for each action

        Returns
        -------
        next_task: int
            The next task id.

        Raises
        ------
        TypeError
            If task_id is neither the next task nor the current (repeated) task.
        """

        if task_id == len(self._knowledge["rewards"]):
            self._knowledge["rewards"].append(rewards)
        elif task_id == len(self._knowledge["rewards"]) - 1:
            self._knowledge["rewards"][-1].extend(rewards)
        else:
            raise TypeError(
                "Step id should be either {0} or {0}-1, get {1}".format(
                    len(self._knowledge["rewards"]), task_id
                )
            )
        return self._store(task_id)

    def _store(self, task_id: int):
        """Store rewards; subclasses may keep the task alive by returning task_id

        Parameters
        ----------
        task_id: int
            The current task id.

        Returns
        -------
        next_task: int
            The next task id.
        """

        return task_id + 1

    def learn(self):
        """Learn from knowledge

        Returns
        -------
        actions: list
            The learned actions.
        rewards: list
            The learned rewards.
        """

        self._logger.debug(msc_utils.msg_block("AGENT.LEARN", self._knowledge))
        return self._learn()

    def _learn(self):
        """Learn from knowledge (implemented by subclasses)

        Returns
        -------
        actions: list
            The learned actions.
        rewards: list
            The learned rewards.
        """

        raise NotImplementedError("_learn is not implemented in BaseAgent")

    def destory(self):
        """Destroy the agent.

        NOTE(review): name kept as ``destory`` (sic) for backward compatibility
        with existing callers.
        """

        return None

    def _execute(self, name: str, *args, **kwargs) -> Any:
        """Run executor

        Parameters
        ----------
        name: str
            The executor name.
        args: list
            The arguments for execute.
        kwargs: dict
            The key word arguments for execute.

        Returns
        -------
        res:
            The execute result.
        """

        assert name in self._executors, "Can not find {} in executors: {}".format(
            name, self._executors.keys()
        )
        _, method, config = self._executors[name]
        # executor config supplies defaults; explicit kwargs win
        kwargs.update({k: v for k, v in config.items() if k not in kwargs})
        return method(self, *args, **kwargs)

    def _evaluate(self, reward: dict) -> float:
        """Evaluate a reward with baseline

        Parameters
        ----------
        reward: dict
            The reward to be evaluated.

        Returns
        -------
        score: float
            The score of the reward.
        """

        return self._execute("evaluate", self._baseline, reward)

    @classmethod
    def agent_type(cls):
        return "base"


msc_utils.register_gym_agent(BaseAgent)
# pylint: disable=unused-argument
"""tvm.contrib.msc.core.gym.agent.method"""

from typing import Any
from tvm.contrib.msc.core import utils as msc_utils


class AgentMethod(object):
    """Default collection of agent evaluate methods"""

    @classmethod
    def evaluate_by_loss(cls, agent: Any, baseline: dict, reward: dict) -> float:
        """Evaluate a reward by the inverse of its raw loss

        Parameters
        ----------
        agent: BaseAgent
            The base agent.
        baseline: dict
            The baseline.
        reward: dict
            The reward; must contain "loss".

        Returns
        -------
        score: float
            The score (higher is better).
        """

        assert "loss" in reward, "loss should be given to evaluate loss"
        loss = reward["loss"]
        # a perfect (zero) loss would otherwise raise ZeroDivisionError
        return float("inf") if loss == 0 else 1 / loss

    @classmethod
    def evaluate_by_thresh(cls, agent: Any, baseline: dict, reward: dict, thresh: float) -> float:
        """Evaluate a reward clipped by a threshold

        Parameters
        ----------
        agent: BaseAgent
            The base agent.
        baseline: dict
            The baseline.
        reward: dict
            The reward; must contain "reward".
        thresh: float
            The threshold

        Returns
        -------
        score: float
            The score, clipped to at most ``thresh``.
        """

        assert "reward" in reward, "reward should be given to evaluate threshold"
        if reward["reward"] >= thresh:
            return thresh
        return reward["reward"]

    @classmethod
    def method_type(cls):
        return "agent.default"


msc_utils.register_gym_method(AgentMethod)
"""tvm.contrib.msc.core.gym.search_agent"""

from typing import Any, List
from tvm.contrib.msc.core import utils as msc_utils
from .base_agent import BaseAgent


class BaseSearchAgent(BaseAgent):
    """Base Search Agent of MSC.Gym"""

    def setup(self) -> dict:
        """Setup the agent

        Returns
        -------
        info: dict
            The setup info.
        """

        # -1 means no limit on the number of search trials
        self._max_search = self._options.get("max_search", -1)
        return super().setup()

    @classmethod
    def agent_type(cls):
        return "search.base"


class GridSearchAgent(BaseSearchAgent):
    """GridSearch agent: tries every candidate action and keeps the best"""

    def _choose_action(
        self, task_id: int, observation: Any, action_space: List[dict]
    ) -> List[dict]:
        """Choose action based on observation

        Parameters
        ----------
        task_id: int
            The current task id.
        observation:
            The current observation.
        action_space: list
            The possible action space

        Returns
        -------
        actions: list
            The actions for next task (grid search tries the whole space).
        """

        return action_space

    def _learn(self):
        """Learn from knowledge by picking the best-scoring action per task

        Returns
        -------
        actions: list
            The learned actions.
        rewards: list
            The learned rewards.
        """

        best_actions = [None] * len(self._knowledge["actions"])
        best_rewards = [None] * len(self._knowledge["rewards"])
        pairs = zip(self._knowledge["actions"], self._knowledge["rewards"])
        for idx, (actions, rewards) in enumerate(pairs):
            best_score = None
            for action, reward in zip(actions, rewards):
                score = self._evaluate(reward)
                if best_score is None or score > best_score:
                    best_actions[idx], best_rewards[idx] = action, reward
                    best_score = score
        return best_actions, best_rewards

    @classmethod
    def agent_type(cls):
        return "search.grid"


class BinarySearchAgent(BaseSearchAgent):
    """BinarySearch agent: halves the action window per task until it collapses"""

    def reset(self):
        """Reset the agent and the per-task search windows"""

        # end == -1 marks an uninitialized window; it is set to
        # len(action_space) on the first _choose_action call
        self._ranges = [{"start": 0, "end": -1} for _ in range(self._max_task)]
        super().reset()

    def _choose_action(
        self, task_id: int, observation: Any, action_space: List[dict]
    ) -> List[dict]:
        """Choose action based on observation

        Parameters
        ----------
        task_id: int
            The current task id.
        observation:
            The current observation.
        action_space: list
            The possible action space

        Returns
        -------
        actions: list
            The actions for next task.
        """

        if self._ranges[task_id]["end"] == -1:
            self._ranges[task_id]["end"] = len(action_space)
            return [action_space[self._ranges[task_id]["start"]]]
        # integer midpoint; true division would yield a float and
        # action_space[pos] would raise TypeError
        pos = (self._ranges[task_id]["start"] + self._ranges[task_id]["end"]) // 2
        return [action_space[pos]]

    def _store(self, task_id: int):
        """Store rewards and shrink the search window

        Parameters
        ----------
        task_id: int
            The current task id.

        Returns
        -------
        next_task: int
            The next task id (task_id itself while the window is still open).
        """

        rewards = self._knowledge["rewards"][task_id]
        start = self._ranges[task_id]["start"]
        end = self._ranges[task_id]["end"]
        if len(rewards) > 1:
            # keep the half that contains the better-scoring probe
            if self._evaluate(rewards[-1]) > self._evaluate(rewards[-2]):
                self._ranges[task_id]["end"] = (start + end) // 2
            else:
                self._ranges[task_id]["start"] = (start + end) // 2
        # advance once the (updated) window has collapsed; the original
        # `start - end <= 1` was always true since start <= end
        if self._ranges[task_id]["end"] - self._ranges[task_id]["start"] <= 1:
            return task_id + 1
        return task_id

    def _learn(self):
        """Learn from knowledge by keeping the last probe per task

        Returns
        -------
        actions: list
            The learned actions.
        rewards: list
            The learned rewards.
        """

        actions = [a[-1] for a in self._knowledge["actions"]]
        rewards = [r[-1] for r in self._knowledge["rewards"]]
        return actions, rewards

    @classmethod
    def agent_type(cls):
        return "search.binary"


msc_utils.register_gym_agent(GridSearchAgent)
# BinarySearchAgent defines agent_type "search.binary" but was never
# registered, making it unreachable via the registry
msc_utils.register_gym_agent(BinarySearchAgent)
+"""tvm.contrib.msc.core.gym.control""" + +from .controller import * +from .configer import * diff --git a/python/tvm/contrib/msc/core/gym/control/configer.py b/python/tvm/contrib/msc/core/gym/control/configer.py new file mode 100644 index 000000000000..00cb54cfd39a --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/control/configer.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.gym.configer""" + +from tvm.contrib.msc.core import utils as msc_utils + + +class BaseConfiger(object): + """Configer for Gym + + Parameters + ---------- + stage: str + The stage for gym, should be in MSCStage. + """ + + def __init__(self, stage: str): + self._stage = stage + + def update(self, raw_config: dict) -> dict: + """Config the raw config + + Parameters + ---------- + raw_config: dict + The raw config. + + Returns + ------- + config: dict + The update config. + """ + + raise NotImplementedError("update is not implemented in BaseConfiger") + + +class DefaultConfiger(BaseConfiger): + """Default configer for gym""" + + def update(self, raw_config: dict) -> dict: + """Config the raw config + + Parameters + ---------- + raw_config: dict + The raw config. 
+ + Returns + ------- + config: dict + The update config. + """ + + config = msc_utils.copy_dict(raw_config) + assert "env" in config and "agent" in config, "env and agent should be given to run gym" + if "env_type" not in config["env"]: + config["env"]["env_type"] = self._stage + ".default" + if "agent_type" not in config["agent"]: + config["agent"]["agent_type"] = "search.grid" + if "executors" not in config["env"]: + config["env"]["executors"] = {} + # update executors + env_executors = { + "reward_runner": {"method": "reward_compare_baseline"}, + "create_tasks": {"method": "tasks_tool_extract"}, + } + config["env"]["executors"].update( + {k: v for k, v in env_executors.items() if k not in config["env"]["executors"]} + ) + if "executors" not in config["agent"]: + config["agent"]["executors"] = {} + agent_executors = {"evaluate": {"method": "evaluate_by_loss"}} + config["agent"]["executors"].update( + {k: v for k, v in agent_executors.items() if k not in config["agent"]["executors"]} + ) + return config + + @classmethod + def config_type(cls): + return "default" + + +msc_utils.register_gym_configer(DefaultConfiger) diff --git a/python/tvm/contrib/msc/core/gym/control/controller.py b/python/tvm/contrib/msc/core/gym/control/controller.py new file mode 100644 index 000000000000..5716169e7678 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/control/controller.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
"""tvm.contrib.msc.core.gym.control.controller"""

from typing import Dict, Any
from tvm.contrib.msc.core import utils as msc_utils
from .service import MainService, NodeService
from .namespace import GYMObject, GYMAction


class BaseController(object):
    """Basic controller for optimize search

    Parameters
    ----------
    workspace: MSCDirectory
        The workspace.
    config: dict
        The config for service.
    is_main: bool
        Whether this node hosts the main service (otherwise a node service).
    """

    def __init__(
        self,
        workspace: msc_utils.MSCDirectory,
        config: Dict[str, Any],
        is_main: bool = True,
    ):
        self._workspace = workspace
        service_cls = MainService if is_main else NodeService
        self._service = service_cls(self._workspace, **config)

    def run(self) -> dict:
        """Run the controller

        Returns
        -------
        report: dict
            The run report.
        """

        self._service.init()
        while not self._service.done:
            self._service.reset()
            # one iteration: observe -> choose -> step -> store, per task
            while not self._service.iter_done:
                self._service.execute(GYMObject.ENV, GYMAction.GET_STATE)
                self._service.execute(GYMObject.AGENT, GYMAction.CHOOSE_ACTION)
                self._service.execute(GYMObject.ENV, GYMAction.STEP)
                self._service.execute(GYMObject.AGENT, GYMAction.STORE)
            self._service.learn()
        return self._service.summary()


def create_controller(stage: str, config: dict, extra_config: dict = None):
    """Create a controller for the given stage

    Parameters
    ----------
    stage: str
        The stage for gym, should be in MSCStage.
    config: dict
        The raw config.
    extra_config: dict
        The extra config

    Returns
    -------
    controller: BaseController
        The created controller.
    """

    # resolve and apply the configer before building the controller
    config_type = config.pop("config_type", "default")
    configer_cls = msc_utils.get_registered_gym_configer(config_type)
    assert configer_cls, "Can not find configer for " + str(config_type)
    config = configer_cls(stage).update(config)
    if extra_config:
        config = msc_utils.update_dict(config, extra_config)
    control_type = config.pop("control_type", "default")
    controller_cls = msc_utils.get_registered_gym_controller(control_type)
    assert controller_cls, "Can not find controller for " + str(control_type)
    return controller_cls(msc_utils.get_gym_dir(), config)


class DefaultController(BaseController):
    """Default controller of MSC.Gym"""

    @classmethod
    def control_type(cls):
        return "default"


msc_utils.register_gym_controller(DefaultController)
+"""tvm.contrib.msc.core.gym.control.namespace""" + + +class GYMObject(object): + """Enum all gym objects""" + + BASE = "base" + ENV = "env" + AGENT = "agent" + SERVICE = "service" + + +class GYMAction(object): + """Enum all gym actions""" + + INIT = "init" + RESET = "reset" + GET_STATE = "get_state" + CHOOSE_ACTION = "choose_action" + STEP = "step" + STORE = "store" + LEARN = "learn" + SUMMARY = "summary" + CLEANUP = "cleanup" diff --git a/python/tvm/contrib/msc/core/gym/control/service.py b/python/tvm/contrib/msc/core/gym/control/service.py new file mode 100644 index 000000000000..6bb97def2545 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/control/service.py @@ -0,0 +1,810 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.core.gym.control.service""" + +import json +import time +import copy +from typing import Dict, Any, List, Tuple +from multiprocessing import Manager +from functools import partial, reduce +import queue +import numpy as np + +from tvm.contrib.msc.core import utils as msc_utils +from .worker import BaseWorker, WorkerFactory +from .namespace import GYMObject, GYMAction + + +def _send_message(msg_queue: queue.Queue, header: str, body: dict, header_type: str = "message"): + """Send the message to queue + + Parameters + ---------- + msg_queue: Queue + The message queue. + header: str + The header of message. + body: dict + The message body. + header_type: str + The header type + """ + + msg_queue.put(json.dumps({header_type: header, "body": body})) + + +def _wait_message( + msg_queue: queue.Queue, + header: str, + checker: callable = None, + wait_time: int = 2, + max_retry: int = -1, + header_type: str = "message", +) -> dict: + """Wait until valid message + + Parameters + ---------- + msg_queue: Queue + The message queue. + header: str + The header of message. + checker: callable + The checker for the message. + wait_time: int + The wait time between retry in second. + max_retry: int + The max retry time. + header_type: str + The header type + + Returns + ------- + message: dict + The message body + """ + + def _check_message(message: dict, checker: callable = None) -> bool: + """Check the message + + Parameters + ---------- + message: dict + The message. + checker: callable + The checker for the message. + + Returns + ------- + pass: bool + Whether the message pass. 
+ """ + + if "body" not in message: + return False + if checker and not checker(message["body"]): + return False + return True + + try_cnt = 0 + while True: + if try_cnt >= max_retry > 0: + break + info = msg_queue.get() + message = json.loads(info) + if message.get(header_type, "") == header and _check_message(message, checker): + return message["body"] + try_cnt += 1 + msg_queue.put(info) + time.sleep(wait_time) + return None + + +send_request = partial(_send_message, header_type="request_header") +send_response = partial(_send_message, header_type="response_header") +wait_request = partial(_wait_message, header_type="request_header") +wait_response = partial(_wait_message, header_type="response_header") + + +class GatherMode(object): + """Enum all gather mode""" + + PARALLEL = "parallel" + REDUCE_SUM = "reduce_sum" + REDUCE_MEAN = "reduce_mean" + FIRST = "first" + + +class BaseService(object): + """Basic service for gym + + Parameters + ---------- + workspace: MSCDirectory + The worksapce. + env: dict + The environment config. + agent: dict + The agent config + tasks: list + The tasks on the node. + world_size: int + The world size. + max_iter: int + The max seatch iter. + record_step: int + The record step. 
+ debug_level: int + The debug level + """ + + def __init__( + self, + workspace: msc_utils.MSCDirectory, + env: Dict[str, Any], + agent: Dict[str, Any], + tasks: List[str] = None, + dist_manager: Manager = None, + world_size: int = 1, + max_iter: int = 1, + record_step: int = 5, + debug_level: int = 0, + ): + self._workspace = workspace + tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"] + verbose = "debug" if debug_level > 0 else "info" + self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG")) + + def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]: + if "debug_level" not in config: + config["debug_level"] = debug_level + if "logger" not in config: + config["logger"] = self._logger + return [ + WorkerFactory.create(t, workspace, config) for t in tasks if t.startswith(obj_type) + ] + + self._env_workers = _create_workers(env, GYMObject.ENV) + self._agent_workers = _create_workers(agent, GYMObject.AGENT) + self._dist_manager = dist_manager + self._world_size = world_size + self._max_iter = max_iter + self._record_step = record_step + self._debug_level = debug_level + self._logger.info( + msc_utils.msg_block("SERVICE.SETUP({})".format(self.service_type), self.setup()) + ) + + def setup(self) -> dict: + """Setup the tool + + Returns + ------- + info: dict + The setup info. 
+ """ + + if self._world_size > 1: + assert self._dist_manager, "dist manager should be given for distributed service" + self._request_queue = self._dist_manager.get_request_queue() + self._response_queue = self._dist_manager.get_response_queue() + self._world_id, self._env_world_ids, self._agent_world_ids = self._connect() + else: + self._request_queue = queue.Queue() + self._response_queue = queue.Queue() + self._world_id = 0 + self._env_world_ids = [w.worker_id for w in self._env_workers] + self._agent_world_ids = [w.worker_id for w in self._agent_workers] + return { + "workspace": self._workspace, + "world_id": self._world_id, + "world_size": self._world_size, + "env_worker_ids": self._get_worker_ids(GYMObject.ENV), + "env_world_ids": self._env_world_ids, + "agent_worker_ids": self._get_worker_ids(GYMObject.AGENT), + "agent_world_ids": self._agent_world_ids, + "max_iter": self._max_iter, + "record_step": self._record_step, + "debug_level": self._debug_level, + } + + def init(self): + self._logger.info("SERVICE Init") + self._iter_id, self._done = 0, False + self._max_task = 0 + self._task_id, self._states = 0, [] + self._iter_done = False + self.execute(GYMObject.ENV, GYMAction.INIT) + self.execute(GYMObject.AGENT, GYMAction.INIT) + + def reset(self): + self._task_id, self._states = 0, [] + self._iter_done = False + self._logger.info("SERVICE Reset %d/%d th iter", self._iter_id, self._max_iter) + self.execute(GYMObject.AGENT, GYMAction.RESET) + self.execute(GYMObject.ENV, GYMAction.RESET) + + def learn(self): + self.execute(GYMObject.AGENT, GYMAction.LEARN) + if self._iter_done: + self._iter_id += 1 + if self._iter_id >= self._max_iter: + self._done = True + + def summary(self): + self._logger.info("SERVICE Summary after %d iters", self._max_iter) + self.execute(GYMObject.ENV, GYMAction.SUMMARY) + plan = self._states[-1]["response"]["plan"] + self.execute(GYMObject.ENV, GYMAction.CLEANUP) + self.execute(GYMObject.AGENT, GYMAction.CLEANUP) + return plan + + def 
execute(self, obj_type: str, act_type: str): + """Execute the service + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + """ + + self._states.append( + { + "task_id": self._task_id, + "msg_key": self._to_msg_key(obj_type, act_type), + "response": self._execute(obj_type, act_type), + } + ) + + def _execute(self, obj_type: str, act_type: str) -> dict: + """Execute the service + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + + Returns + ------- + state: dict + The state after the execute. + """ + + raise NotImplementedError("_execute is not implemented in BaseService") + + def _send_request(self, msg_key: str, body: dict): + """Send request + + Parameters + ---------- + msg_key: str + The header of message. + body: dict + The message body. + """ + + send_request(self._request_queue, msg_key, body) + + def _send_response(self, msg_key: str, body: dict): + """Send response + + Parameters + ---------- + msg_key: str + The header of message. + body: dict + The message body. + """ + + send_response(self._response_queue, msg_key, body) + + def _wait_request( + self, + msg_key: str, + checker: callable = None, + wait_time: int = 2, + max_retry: int = -1, + ) -> dict: + """Wait request + + Parameters + ---------- + msg_key: str + The header of message. + checker: callable + The checker for the message. + wait_time: int + The wait time between retry in second. + max_retry: int + The max retry time. + """ + + return wait_request(self._request_queue, msg_key, checker, wait_time, max_retry) + + def _wait_response( + self, + msg_key: str, + checker: callable = None, + wait_time: int = 2, + max_retry: int = -1, + ) -> dict: + """Wait response + + Parameters + ---------- + msg_key: str + The header of message. + checker: callable + The checker for the message. 
+ wait_time: int + The wait time between retry in second. + max_retry: int + The max retry time. + """ + + return wait_response(self._response_queue, msg_key, checker, wait_time, max_retry) + + def _process_request(self, msg_key: str) -> dict: + """Process the request according to msg_key + + Parameters + ---------- + msg_key: str + The header of message. + + Returns + ------- + responses: dict + The responses of workers. + """ + + obj_type, act_type = self._from_msg_key(msg_key) + workers = {w.worker_id: w for w in self._get_workers(obj_type)} + requests = self._wait_request(msg_key) + if act_type in (GYMAction.INIT, GYMAction.RESET): + mark = "I[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type) + else: + mark = "I[{}/{}].T[{}/{}] {}.{}".format( + self._iter_id, self._max_iter, self._task_id, self._max_task, obj_type, act_type + ) + requests = {int(k): v for k, v in requests.items()} + responses = {} + for w_id, worker in workers.items(): + responses[w_id] = worker.execute(act_type, **requests[w_id]) + info = { + "requests": {workers[w].name: r for w, r in requests.items()}, + "responses": {workers[w].name: r for w, r in responses.items()}, + } + self._logger.info(msc_utils.msg_table(mark, info)) + return responses + + def _process_response(self, msg_key: str, response: dict): + """Update response + + Parameters + ---------- + msg_key: str + The header of message. + response: dict + The response. + + Returns + ------- + response: dict + The updated response. 
+ """ + + obj_type, act_type = self._from_msg_key(msg_key) + if obj_type == GYMObject.ENV and act_type == GYMAction.INIT: + self._max_task = response["max_task"] + if obj_type == GYMObject.AGENT and act_type == GYMAction.STORE: + self._task_id = response["next_task"] + if self._task_id >= self._max_task: + self._iter_done = True + return response + + def _to_msg_key(self, obj_type: str, act_type: str) -> str: + """Create message key base on types + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + + Returns + ------- + key: str + The message key. + """ + + return "{}-s-{}".format(obj_type, act_type) + + def _from_msg_key(self, msg_key: str) -> Tuple[str, str]: + """Get obj_type and act_type from message key + + Parameters + ---------- + msg_key: str + The message key. + + Returns + ------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + """ + + return msg_key.split("-s-") + + def _get_workers(self, obj_type: str) -> List[BaseWorker]: + """Get workers according to obj_type + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + + Returns + ------- + workers: list + The workers. + """ + + if obj_type == GYMObject.ENV: + return self._env_workers + if obj_type == GYMObject.AGENT: + return self._agent_workers + return [] + + def _get_worker_ids(self, obj_type: str) -> List[int]: + """Get worker ids according to obj_type + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + + Returns + ------- + worker_ids: list + The worker ids. + """ + + return [w.worker_id for w in self._get_workers(obj_type)] + + def _get_world_ids(self, obj_type: str) -> List[int]: + """Get world ids according to obj_type + + Parameters + obj_type: str + The object type, should be one of GYMObject. 
+ + Returns + ------- + world_ids: list + The world ids. + """ + + if obj_type == GYMObject.ENV: + return self._env_world_ids + if obj_type == GYMObject.AGENT: + return self._agent_world_ids + return [] + + @property + def done(self): + return self._done + + @property + def iter_done(self): + return self._iter_done + + @property + def service_type(self): + return "base" + + +class MainService(BaseService): + """Main service for gym""" + + def _connect(self): + msg_key = self._to_msg_key(GYMObject.SERVICE, GYMAction.SETUP) + env_world_ids = self._get_worker_ids(GYMObject.ENV) + agent_world_ids = self._get_worker_ids(GYMObject.AGENT) + # send world_id and get env/agent ids + barrier = self._world_size - 1 + + def _check_response(body): + return all(k in body for k in ["env_worker_ids", "agent_worker_ids"]) + + for i in range(barrier): + self._send_request(msg_key, {"world_id": i + 1}) + while barrier > 0: + info = self._wait_response(msg_key, _check_response) + if info: + env_world_ids.extend(info["env_world_ids"]) + agent_world_ids.extend(info["agent_world_ids"]) + barrier -= 1 + + self._synchronize_feedback( + msg_key, env_world_ids=env_world_ids, agent_world_ids=agent_world_ids + ) + return 0, env_world_ids, agent_world_ids + + def _execute(self, obj_type: str, act_type: str) -> dict: + """Execute the service + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + + Returns + ------- + state: dict + The state after the execute. 
+ """ + + world_ids = self._get_worker_ids(obj_type) + tasks = {i: self._create_task(obj_type, act_type, i) for i in world_ids} + msg_key = self._to_msg_key(obj_type, act_type) + response = self._synchronize_request(msg_key, tasks) + response = self._process_response(msg_key, response) + self._synchronize_feedback(msg_key, **response) + return response + + def _synchronize_request( + self, + msg_key: str, + requests: List[dict], + checker: callable = None, + wait_time: int = 2, + max_retry: int = -1, + ) -> dict: + """Send requests to workers and gather response + + Parameters + ---------- + msg_key: str + The header of message. + requests: list + The requests + checker: callable + The checker for the response. + wait_time: int + The wait time between retry in second. + max_retry: int + The max retry time. + + Returns + ------- + response: dict + The gathered response. + """ + + responses = {} + barrier = self._world_size + for _ in range(barrier): + self._send_request(msg_key, requests) + responses.update(self._process_request(msg_key)) + barrier -= 1 + while barrier > 0: + info = self._wait_response(msg_key, checker, wait_time, max_retry) + if info: + info = {int(k): v for k, v in info.items()} + responses.update(info) + barrier -= 1 + responses = [responses[i] for i in sorted(responses)] + gathered_response = {} + for key in responses[0]: + if key in ("action", "reward"): + gather_mode = GatherMode.PARALLEL + else: + gather_mode = GatherMode.FIRST + gathered_response[key] = self._gather_values([r[key] for r in responses], gather_mode) + return gathered_response + + def _synchronize_feedback(self, msg_key: str, **feedback: dict): + """Broadcast feedback to workers + + Parameters + ---------- + msg_key: str + The header of message. 
+ feedback: dict + The feedback body + """ + + def _check_feedback(body): + return body.get("feedback_receive", False) + + barrier = self._world_size - 1 + for _ in range(barrier): + self._send_request(msg_key, {"feedback_send": True, **feedback}) + while barrier > 0: + info = self._wait_response(msg_key, _check_feedback) + if info: + barrier -= 1 + + def _create_task(self, obj_type: str, act_type: str, worker_id: int) -> dict: + """Create message key base on types + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + worker_id: int + The worker id. + + Returns + ------- + config: dict + The config for the worker.execute. + """ + + if not self._states: + config = {} + else: + config = copy.deepcopy(self._states[-1]["response"]) + if obj_type == GYMObject.ENV and act_type == GYMAction.GET_STATE: + config["task_id"] = self._task_id + if obj_type == GYMObject.ENV and act_type == GYMAction.STEP: + config["actions"] = self._map_values(config["actions"], obj_type, worker_id) + config["task_id"] = self._task_id + elif obj_type == GYMObject.AGENT and act_type in (GYMAction.CHOOSE_ACTION, GYMAction.STORE): + config["task_id"] = self._task_id + return config + + def _map_values(self, values: List[Any], obj_type: str, worker_id: int) -> List[Any]: + """Map the values for worker + + Parameters + ---------- + values: list + The global values, + obj_type: str + The object type, should be one of GYMObject. + worker_id: int + The worker id. + + Returns + ------- + values: list + The values for the worker. 
+ """ + + world_ids = self._get_world_ids(obj_type) + tile_size = len(values) // len(world_ids) + if len(values) % len(world_ids) != 0: + tile_size += 1 + worker_idx = world_ids.index(worker_id) + start = worker_idx * tile_size + end = min((worker_idx + 1) * tile_size, len(values)) + return values[start:end] + + def _gather_values(self, values: List[Any], gather_mode: str) -> Any: + """Gather the values + + Parameters + ---------- + values: list + The global values, + gather_mode: str + The gather mode should be in GatherMode. + + Returns + ------- + value: + The gathered value. + """ + + if gather_mode == GatherMode.FIRST or len(values) == 1: + return values[0] + if gather_mode == GatherMode.PARALLEL: + return values + if gather_mode in (GatherMode.REDUCE_MEAN, GatherMode.REDUCE_SUM): + if all(msc_utils.MSCArray.is_array(v) for v in values): + value_sum = np.array([msc_utils.cast_array(v) for v in values]).sum(axis=0) + else: + value_sum = reduce(lambda x, y: x + y, values) + if gather_mode == GatherMode.REDUCE_SUM: + return value_sum + return value_sum / len(values) + raise NotImplementedError("Gather mode {} is not supported".format(gather_mode)) + + @property + def service_type(self): + return "main" + + +class NodeService(BaseService): + """Normal service for gym""" + + def _connect(self): + msg_key = self._to_msg_key(GYMObject.SERVICE, GYMAction.SETUP) + env_worker_ids = self._get_worker_ids(GYMObject.ENV) + agent_worker_ids = self._get_worker_ids(GYMObject.AGENT) + + def _check_request(body): + return "world_id" in body + + info = self._wait_request(msg_key, _check_request) + world_id = info["world_id"] + self._send_response( + msg_key, {"env_worker_ids": env_worker_ids, "agent_worker_ids": agent_worker_ids} + ) + info = self._feedback(msg_key) + return world_id, info["env_world_ids"], info["agent_world_ids"] + + def _feedback(self, msg_key: str) -> dict: + """Send feed back to main service + + Parameters + ---------- + msg_key: str + The header of message. 
+ + Returns + ------- + response: dict + The recived feedback. + """ + + def _check_feedback(body): + return body.get("feedback_send", False) + + response = self._wait_request(msg_key, _check_feedback) + self._send_response(msg_key, {"feedback_receive": True}) + response = self._process_response(msg_key, response) + return response + + def _execute(self, obj_type: str, act_type: str) -> dict: + """Execute the service + + Parameters + ---------- + obj_type: str + The object type, should be one of GYMObject. + act_type: str + The action type, should be one of GYMAction. + + Returns + ------- + state: dict + The state after the execute. + """ + + msg_key = self._to_msg_key(obj_type, act_type) + info = self._process_request(msg_key) + self._send_response(msg_key, info) + return self._feedback(msg_key) + + @property + def service_type(self): + return "node" diff --git a/python/tvm/contrib/msc/core/gym/control/worker.py b/python/tvm/contrib/msc/core/gym/control/worker.py new file mode 100644 index 000000000000..7ccfb5da38e2 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/control/worker.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.core.gym.control.worker""" + +from typing import Any +from tvm.contrib.msc.core import utils as msc_utils +from .namespace import GYMObject, GYMAction + + +class BaseWorker(object): + """Basic worker for gym + + Parameters + ---------- + name: str + The worker name. + workspace: MSCDirectory + The worksapce. + worker_id: int + The worker_id. + worker_cls: class + The worker class. + worker_config: dict + The worker config. + """ + + def __init__( + self, + name: str, + workspace: msc_utils.MSCDirectory, + worker_id: int, + worker_cls: Any, + worker_config: dict, + ): + self._name = name + self._worker_id = worker_id + debug_level = worker_config.get("debug_level", 0) + if "logger" not in worker_config: + verbose = "debug" if debug_level > 0 else "info" + worker_config["logger"] = msc_utils.create_file_logger( + verbose, workspace.relpath("{}.{}_LOG".format(self.obj_type.upper(), worker_id)) + ) + if "workspace" not in worker_config: + worker_config["workspace"] = workspace + worker_config["name"] = name + self._worker_impl = worker_cls(**worker_config) + + def __str__(self): + return "<{}>: {}({})".format(self.obj_type, self._name, self._worker_id) + + def execute(self, act_type: str, **kwargs) -> Any: + """Execute the worker + + Parameters + ---------- + act_type: str + The action type, should be one of GYMAction. + kwargs: dict + The kwargs for execute. + + Returns + ------- + response: dict + The execute result. + """ + + raise NotImplementedError("execute is not implemented in BaseWorker") + + @property + def obj_type(self): + return GYMObject.BASE + + @property + def name(self): + return self._name + + @property + def worker_id(self): + return self._worker_id + + +class EnvWorker(BaseWorker): + """Env worker for gym""" + + def execute(self, act_type: str, **kwargs) -> Any: + """Execute the worker + + Parameters + ---------- + act_type: str + The action type, should be one of GYMAction. + kwargs: dict + The kwargs for execute. 
+ + Returns + ------- + response: dict + The execute result. + """ + + response = {} + if act_type == GYMAction.INIT: + max_task, baseline = self._worker_impl.init() + response.update({"max_task": max_task, "baseline": baseline}) + elif act_type == GYMAction.RESET: + self._worker_impl.reset() + elif act_type == GYMAction.GET_STATE: + observation, action_space = self._worker_impl.get_state(kwargs["task_id"]) + response.update({"observation": observation, "action_space": action_space}) + elif act_type == GYMAction.STEP: + rewards = self._worker_impl.step(**kwargs) + response.update({"rewards": rewards}) + elif act_type == GYMAction.SUMMARY: + plan = self._worker_impl.summary(**kwargs) + response.update({"plan": plan}) + elif act_type == GYMAction.CLEANUP: + self._worker_impl.destory() + return response + + @property + def obj_type(self): + return GYMObject.ENV + + +class AgentWorker(BaseWorker): + """Agent worker for gym""" + + def execute(self, act_type: str, **kwargs) -> Any: + """Execute the worker + + Parameters + ---------- + act_type: str + The action type, should be one of GYMAction. + kwargs: dict + The kwargs for execute. + + Returns + ------- + response: dict + The execute result. 
+ """ + + response = {} + if act_type == GYMAction.INIT: + self._worker_impl.init(**kwargs) + elif act_type == GYMAction.RESET: + self._worker_impl.reset() + elif act_type == GYMAction.CHOOSE_ACTION: + actions = self._worker_impl.choose_action(**kwargs) + response.update({"actions": actions}) + elif act_type == GYMAction.STORE: + next_task = self._worker_impl.store(**kwargs) + response.update({"next_task": next_task}) + elif act_type == GYMAction.LEARN: + actions, rewards = self._worker_impl.learn() + response.update({"actions": actions, "rewards": rewards}) + elif act_type == GYMAction.CLEANUP: + self._worker_impl.destory() + return response + + @property + def obj_type(self): + return GYMObject.AGENT + + +class WorkerFactory(object): + """The Factory for workers""" + + @classmethod + def create(cls, name: str, workspace: msc_utils.MSCDirectory, config: dict) -> BaseWorker: + """Create worker + + Parameters + ---------- + name: str + The name of worker, should be in type. + workspace: MSCDirectory + The worksapce. + worker_id: int + The worker_id. + worker_cls: class + The worker class. + worker_config: dict + The worker config. + + Returns + ------- + worker: BaseWorker + The create worker. 
+ """ + + obj_type, worker_id = name.split(":") + if obj_type == GYMObject.ENV: + env_type = config.pop("env_type") if "env_type" in config else "default" + worker_cls = msc_utils.get_registered_gym_env(env_type) + return EnvWorker(name, workspace, int(worker_id), worker_cls, config) + if obj_type == GYMObject.AGENT: + agent_type = config.pop("agent_type") if "agent_type" in config else "default" + worker_cls = msc_utils.get_registered_gym_agent(agent_type) + return AgentWorker(name, workspace, int(worker_id), worker_cls, config) + raise TypeError("Worker for {} is not supported".format(obj_type)) diff --git a/python/tvm/contrib/msc/core/gym/environment/__init__.py b/python/tvm/contrib/msc/core/gym/environment/__init__.py new file mode 100644 index 000000000000..211b02d32f3a --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/environment/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.core.gym.environment""" + +from .method import * +from .prune_env import * +from .quantize_env import * diff --git a/python/tvm/contrib/msc/core/gym/environment/base_env.py b/python/tvm/contrib/msc/core/gym/environment/base_env.py new file mode 100644 index 000000000000..edaea5120556 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/environment/base_env.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.gym.base_env""" + +import copy +import logging +from typing import Dict, Any, List, Tuple +from tvm.contrib.msc.core.runtime import BaseRunner +from tvm.contrib.msc.core.tools import BaseTool +from tvm.contrib.msc.core import utils as msc_utils + + +class BaseEnv(object): + """Basic Environment of MSC.Gym + + Parameters + ---------- + runner: BaseRunner + The runner. + data_loader: + The data_loader + workspace: MSCDirectory + The worksapce. + executors: dict + The executors of the environment. + knowledge: dict + The predefined knowledge. + options: dict + The extra options for the environment. + debug_level: int + The debug level. + verbose_step: int + The verbose interval step. 
+ logger: logging.Logger + The logger + """ + + def __init__( + self, + name: str, + runner: BaseRunner, + data_loader: Any, + workspace: msc_utils.MSCDirectory, + executors: dict, + knowledge: dict = None, + options: dict = None, + max_tasks: int = -1, + debug_level: int = 0, + logger: logging.Logger = None, + ): + self._name = name + self._runner = runner + self._data_loader = data_loader + self._workspace = workspace + self._knowledge = knowledge + self._executors = self._parse_executors(msc_utils.copy_dict(executors)) + self._options = options or {} + self._max_tasks = max_tasks + self._debug_level = debug_level + if logger: + self._logger = logger + else: + verbose = "debug" if debug_level > 0 else "info" + self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("ENV_LOG")) + self._logger.info( + msc_utils.msg_block("ENV.SETUP({})".format(self.env_type()), self.setup()) + ) + + def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]: + """Parse the executors + + Parameters + ---------- + executors_dict: dict + The given executors. + + Returns + ------- + executors_dict: dict + The parsed executors. + """ + + executors = {} + for name, raw_config in executors_dict.items(): + method_type = ( + raw_config.pop("method_type") if "method_type" in raw_config else "env.default" + ) + method_cls = msc_utils.get_registered_gym_method(method_type) + assert "method" in raw_config, "method should be given to find enviironment method" + method_name, method = raw_config.pop("method"), None + if hasattr(method_cls, method_name): + method = getattr(method_cls, method_name) + if not method: + method = msc_utils.get_registered_func(method_name) + assert method, "Can not find method " + str(method_name) + executors[name] = (method_name, method, copy.deepcopy(raw_config)) + return executors + + def setup(self) -> dict: + """Setup the environment + + Returns + ------- + info: dict + The setup info. 
+ """ + + self._cache_dir = self._workspace.create_dir("Cache") + self._tasks = [] + return { + "name": self._name, + "runner": self._runner, + "data_loader": self._data_loader, + "workspace": self._workspace, + "executors": {k: "{}({})".format(v[0], v[2]) for k, v in self._executors.items()}, + "options": self._options, + "max_tasks": self._max_tasks, + "debug_level": self._debug_level, + } + + def init(self) -> Tuple[int, Dict[str, Any]]: + """Init the agent + + Returns + ------- + max_tasks: int + The max task for agent. + baseline: dict + The baseline of environment. + """ + + self._runner.change_logger(self._logger) + # save cache for tasks + self._runner.save_cache(self._cache_dir) + self._tool = self._init_tool() + # create tasks + self._tasks = self._execute("create_tasks", self._tool) + if self._max_tasks > 0: + self._tasks = self._tasks[: self._max_tasks] + # get baseline + self._tool.disable() + self._runner.build(self._cache_dir, force_build=True) + baseline = self._reward_runner(-1) + self._tool.enable() + tasks_info = {"tasks_num": len(self._tasks), "tasks": self._tasks} + self._logger.info(msc_utils.msg_block("ENV.TASKS", tasks_info, width=0)) + return len(self._tasks), baseline + + def _init_tool(self) -> BaseTool: + """Get the main tool""" + + raise NotImplementedError("_init_tool is not implemented in BaseEnv") + + def reset(self) -> Tuple[List[float], List[dict]]: + """Reset the environment + + Returns + ------- + observation: list + The next observation. + action_space: list + The next action space. + """ + + return None + + def get_state(self, task_id: int) -> Tuple[List[float], List[dict]]: + """Get the state + + Parameters + ---------- + task_id: int + The current task id. + + Returns + ------- + observation: list + The next observation. + action_space: list + The next action space. 
+ """ + + if "observation" in self._executors: + observation = self._execute("observation", task_id) + else: + observation = [task_id] + if "action_space" in self._executors: + action_space = self._execute("action_space", task_id) + else: + action_space = list(range(5)) + return observation, action_space + + def step(self, actions: List[dict], task_id: int) -> Tuple[List[float], List[dict], List[dict]]: + """Step and get rewards + + Parameters + ---------- + actions: list + The current actions. + task_id: int + The current task id. + + Returns + ------- + observation: list + The next observation. + action_space: list + The next action space. + rewards: list + The rewards + """ + + rewards = [] + for idx, action in enumerate(actions): + self._update_tool(action, task_id) + self._runner.build(self._cache_dir, force_build=True) + rewards.append(self._reward_runner(task_id)) + self._logger.info( + "Task[%d/%d] Action[%d/%d] %s -> reward %s", + task_id, + len(self._tasks), + idx, + len(actions), + action, + rewards[-1], + ) + return rewards + + def _update_tool(self, action: dict, task_id: int): + """Update the tool + + Parameters + ---------- + action: dict + The current action. + task_id: int + The current task id. + """ + + raise NotImplementedError("_update_tool is not implemented in BaseEnv") + + def summary(self, actions: List[dict], rewards: List[dict]) -> dict: + """Summary the final plan + + Parameters + ---------- + actions: list + The final actions. + rewards: list + The final rewards. + + Returns + ------- + plan: dict + The final plan. + """ + + self._logger.info("Env Summary with %d actions, %d rewards", len(actions), len(rewards)) + return self._summary(actions, rewards) + + def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + """Summary the final plan + + Parameters + ---------- + actions: list + The final actions. + rewards: list + The final rewards. + + Returns + ------- + plan: dict + The final plan. 
+ """ + + raise NotImplementedError("_summary is not implemented in BaseEnv") + + def get_task(self, task_id: int) -> dict: + """Get task according to task_id + + Parameters + ---------- + task_id: int + The task id. + + Returns + ------- + task_config: dict + The task config. + """ + + return self._tasks[task_id] + + def destory(self): + """Destory the environment""" + + return None + + def _reward_runner(self, task_id: int) -> dict: + """Reward runner for current task + + Parameters + ---------- + task_id: int + The current task id. + + Returns + ------- + reward: dict + The reward + """ + + if "reward_runner" in self._executors: + return self._execute("reward_runner", self._runner, self._data_loader, task_id) + elif "reward_outputs" in self._executors: + reward = {} + for inputs in self._data_loader(): + outputs = self._runner.run(inputs) + reward = self._execute("reward_outputs", reward, outputs, task_id) + return reward + else: + raise Exception("reward_runner or reward_outputs should be given in executors") + + def _execute(self, name: str, *args, **kwargs) -> Any: + """Run executor + + Parameters + ---------- + name: str + The executor name. + args: list + The arguments for execute. + kwargs: dict + The key word arguments for execute. + + Returns + ------- + res: + The execute result. 
+ """ + + assert name in self._executors, "Can not find {} in executors: {}".format( + name, self._executors.keys() + ) + _, method, config = self._executors[name] + kwargs.update({k: v for k, v in config.items() if k not in kwargs}) + return method(self, *args, **kwargs) + + @classmethod + def env_type(cls): + return "base" diff --git a/python/tvm/contrib/msc/core/gym/environment/method.py b/python/tvm/contrib/msc/core/gym/environment/method.py new file mode 100644 index 000000000000..66fe573d932f --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/environment/method.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""tvm.contrib.msc.core.gym.agent.method""" + +from typing import Any, List +import numpy as np + +from tvm.contrib.msc.core.runtime import BaseRunner +from tvm.contrib.msc.core.tools import BaseTool +from tvm.contrib.msc.core import utils as msc_utils + + +class EnvMethod(object): + """Default prune method""" + + @classmethod + def tasks_tool_extract(cls, env: Any, tool: BaseTool, **kwargs) -> List[dict]: + """Extract tasks from tool + + Parameters + ---------- + env: BaseEnv + The evironment. 
+ tool: BaseTool + The main tool + kwargs: dict + The kwargs for create tasks. + + Returns + ------- + tasks: list + The tasks for environment. + """ + + return tool.create_tasks(**kwargs) + + @classmethod + def reward_compare_baseline( + cls, + env: Any, + runner: BaseRunner, + data_loader: callable, + task_id: int, + loss_type: str = "lp_norm", + loss_config: dict = None, + ) -> dict: + """Reward runner with baseline + + Parameters + ---------- + env: BaseEnv + The evironment. + runner: BaseRunner + The runner. + data_loader: callable + The data loader. + task_id: int + The task id. + loss_type: str + The loss type + loss_config: dict + The loss config + + Returns + ------- + reward: dict + The reward. + """ + + datas_path = env._workspace.create_dir("Baseline").path + if task_id == -1: + with msc_utils.SimpleDataSaver(datas_path) as saver: + for inputs in data_loader(): + outputs = runner.run(inputs) + saver.save_datas(outputs) + return {"loss": 1} + + loss_config = loss_config or {} + loader, loss = msc_utils.SimpleDataLoader(datas_path), 0 + + def _get_loss(golden, result): + if loss_type == "lp_norm": + power = loss_config.get("power", 2) + return np.mean(np.power(np.abs(golden - result), power)) + raise NotImplementedError("loss type {} is not implemented".format(loss_type)) + + for idx, inputs in enumerate(data_loader()): + outputs = runner.run(inputs) + baseline = loader[idx] + for name, data in outputs.items(): + loss += _get_loss(baseline[name], data) + return {"loss": loss / len(loader)} + + @classmethod + def action_linear_space( + cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 + ) -> List[float]: + """Get linear action space + + Parameters + ---------- + env: BaseEnv + The evironment. + task_id: int + The task id. + start: float + The start value. + end: float + The end value. + step: float + The step value. + + Returns + ------- + actions: list + The actions. 
+ """ + + actions = [start] + while actions[-1] < end: + actions.append(actions[-1] + step) + return actions + + @classmethod + def action_prune_density( + cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 + ) -> List[dict]: + """Get linear density + + Parameters + ---------- + env: BaseEnv + The evironment. + task_id: int + The task id. + start: float + The start value. + end: float + The end value. + step: float + The step value. + + Returns + ------- + actions: list + The actions. + """ + + return [{"density": a} for a in cls.action_linear_space(env, task_id, start, end, step)] + + @classmethod + def action_quantize_scale( + cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 + ) -> List[dict]: + """Get linear density + + Parameters + ---------- + env: BaseEnv + The evironment. + task_id: int + The task id. + start: float + The start value. + end: float + The end value. + step: float + The step value. + + Returns + ------- + actions: list + The actions. + """ + + task = env.get_task(task_id) + return [ + {"scale": task["scale"] * a} + for a in cls.action_linear_space(env, task_id, start, end, step) + ] + + @classmethod + def method_type(cls): + return "env.default" + + +msc_utils.register_gym_method(EnvMethod) diff --git a/python/tvm/contrib/msc/core/gym/environment/prune_env.py b/python/tvm/contrib/msc/core/gym/environment/prune_env.py new file mode 100644 index 000000000000..8f8a53567ef8 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/environment/prune_env.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.gym.prune_env""" + +from typing import List +from tvm.contrib.msc.core.tools import BaseTool, ToolType +from tvm.contrib.msc.core import utils as msc_utils +from .base_env import BaseEnv + + +class PruneEnv(BaseEnv): + """Environment for prune""" + + def _init_tool(self) -> BaseTool: + """Get the main tool""" + + config = self._runner.get_tool_config(ToolType.PRUNER) + self._meta_strategys = config["strategys"] + for s in self._meta_strategys: + s.update({"density": 1}) + return self._runner.get_tool(ToolType.PRUNER) + + def _update_tool(self, action: dict, task_id: int): + """Update the tool + + Parameters + ---------- + action: dict + The current action. + task_id: int + The current task id. + """ + + task_strategy = self._get_strategy(action, task_id) + self._tool.plan_by_strategys(self._meta_strategys + [task_strategy]) + + def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + """Summary the final plan + + Parameters + ---------- + actions: list + The final actions. + rewards: list + The final rewards. + + Returns + ------- + plan: dict + The final plan. + """ + + strategys = [self._get_strategy(act, idx) for idx, act in enumerate(actions)] + return self._tool.plan_by_strategys(self._meta_strategys + strategys) + + def _get_strategy(self, action: dict, task_id: int) -> dict: + """Get strategy from task_id + + Parameters + ---------- + action: float + The current action. + task_id: int + The current task id. + + Returns + ------- + strategy: dict + The strategy. 
+ """ + + strategy = msc_utils.copy_dict(self.get_task(task_id)) + strategy.update(**action) + return strategy + + @classmethod + def env_type(cls): + return msc_utils.MSCStage.PRUNE + ".default" + + +msc_utils.register_gym_env(PruneEnv) diff --git a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py new file mode 100644 index 000000000000..0a5210b83032 --- /dev/null +++ b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.core.gym.quantize_env""" + +import os +from typing import List +from tvm.contrib.msc.core.tools import BaseTool, ToolType +from tvm.contrib.msc.core import utils as msc_utils +from .base_env import BaseEnv + + +class QuantizeEnv(BaseEnv): + """Environment for quantize""" + + def _init_tool(self) -> BaseTool: + """Get the main tool""" + + plan_file = self._runner.apply_tool(ToolType.QUANTIZER, self._data_loader) + self._meta_plan = msc_utils.load_dict(plan_file) + os.remove(plan_file) + return self._runner.get_tool(ToolType.QUANTIZER) + + def _update_tool(self, action: dict, task_id: int): + """Update the tool + + Parameters + ---------- + action: dict + The current action. + task_id: int + The current task id. + """ + + plan = msc_utils.copy_dict(self._meta_plan) + plan.update(self._get_plan(action, task_id)) + self._tool.set_plan(plan) + + def _summary(self, actions: List[dict], rewards: List[dict]) -> dict: + """Summary the final plan + + Parameters + ---------- + actions: list + The final actions. + rewards: list + The final rewards. + + Returns + ------- + plan: dict + The final plan. + """ + + plan = msc_utils.copy_dict(self._meta_plan) + for idx, act in enumerate(actions): + plan.update(self._get_plan(act, idx)) + return plan + + def _get_plan(self, action: dict, task_id: int) -> dict: + """Get plan from task_id + + Parameters + ---------- + action: float + The current action. + task_id: int + The current task id. + + Returns + ------- + plan: dict + The plan. 
+ """ + + plan = msc_utils.copy_dict(self.get_task(task_id)) + plan.update(**action) + name = plan.pop("name") + return {name: plan} + + @classmethod + def env_type(cls): + return msc_utils.MSCStage.QUANTIZE + ".default" + + +msc_utils.register_gym_env(QuantizeEnv) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 577fe2bf4f9a..a946ef1611e7 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -448,6 +448,60 @@ def get_meta_data(self, name: str) -> np.ndarray: "Can not find data {} from {} weights".format(name, len(self._meta_weights)) ) + def create_tasks(self, **kwargs) -> List[dict]: + """Create tasks for gym + + Parameters + ---------- + kwargs: dict + The kwargs for create tasks. + + Returns + ------- + tasks: list + The tasks. + """ + + tasks = [] + for w_node in self.get_w_nodes(): + if w_node.get_attr("weight_strategy") != "main": + continue + consumer = self.find_producer(w_node.name).name + strategy = self._get_tensor_strategy(w_node.name, consumer) + tasks.append( + { + "tensor_names": [self.to_tensor_id(w_node.name, consumer)], + **strategy.meta, + } + ) + return tasks + + def plan_by_strategys(self, strategys: List[dict]) -> dict: + """Plan the pruning with startegys and get plan + + Parameters + ------- + strategys: list + The given strategys + + Returns + ------- + plan: dict + The plan after new strategy applied. 
+ """ + + self._tensor_cache, self._processed_tensor = {}, {} + self._plan = {} + self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys)) + info = {k: v.inspect() for k, v in self._strategys.items()} + title = "{}.PRUNE_STRATEGYS".format(self.tool_type().upper()) + self._logger.debug(msc_utils.msg_block(title, info, width=0)) + for w_node in self.get_w_nodes(): + consumer = self.find_consumers(w_node.name)[0] + self.process_tensor(w_node.weight, w_node.name, consumer.name, "") + self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} + return self._plan + def finalize(self) -> dict: """Get the plan""" diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 480705f31b52..adeea3b2226a 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -814,6 +814,22 @@ def _process_tensor( return tensor + def create_tasks(self, **kwargs) -> List[dict]: + """Create tasks for gym + + Parameters + ---------- + kwargs: dict + The kwargs for create tasks. + + Returns + ------- + tasks: list + The tasks. 
+ """ + + return [] + def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]: """Update the generate configs diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index ada9745ff6fb..5a6f6c0ff649 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -362,6 +362,7 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T get_cache_dir = partial(get_workspace_subdir, name="Cache") get_config_dir = partial(get_workspace_subdir, name="Config") get_dataset_dir = partial(get_workspace_subdir, name="Dataset") +get_gym_dir = partial(get_workspace_subdir, name="Gym") get_output_dir = partial(get_workspace_subdir, name="Output") get_visual_dir = partial(get_workspace_subdir, name="Visual") get_weights_dir = partial(get_workspace_subdir, name="Weights") diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index 31ae8942a106..50f0b8cd17b3 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -27,6 +27,11 @@ class MSCRegistery: MSC_FUNCS = "msc_funcs" MSC_TOOLS_CLS = "msc_tools_cls" MSC_TOOLS_METHOD = "msc_tools_method" + GYM_CONFIGERS = "gym_configers" + GYM_CONTROLLERS = "gym_controllers" + GYM_AGENTS = "gym_agents" + GYM_ENVS = "gym_envs" + GYM_METHODS = "gym_agents_method" @classmethod def register(cls, key: str, value: Any): @@ -185,3 +190,170 @@ def get_registered_tool_method( tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) register_name = "{}.{}".format(tool_type, method_style) return tools_method.get(framework, {}).get(register_name) + + +def register_gym_configer(configer: Any): + """Register a gym configer. + + Parameters + ---------- + configer: class + The configer class. 
+ """ + + configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) + assert hasattr(configer, "config_type"), "config_type should be given to register configer" + configers[configer.config_type()] = configer + MSCRegistery.register(MSCRegistery.GYM_CONFIGERS, configers) + + +def get_registered_gym_configer(config_type: str) -> Any: + """Get the registered configer. + + Parameters + ---------- + config_type: string + The type of configer. + + Returns + ------- + configer: class + The configer class. + """ + + configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) + return configers.get(config_type) + + +def register_gym_controller(controller: Any): + """Register a gym controller. + + Parameters + ---------- + controller: class + The controller class. + """ + + controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) + assert hasattr( + controller, "control_type" + ), "control_type should be given to register controller" + controllers[controller.control_type()] = controller + MSCRegistery.register(MSCRegistery.GYM_CONTROLLERS, controllers) + + +def get_registered_gym_controller(control_type: str) -> Any: + """Get the registered controller. + + Parameters + ---------- + control_type: string + The type of controller. + + Returns + ------- + controller: class + The controller class. + """ + + controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) + return controllers.get(control_type) + + +def register_gym_agent(agent: Any): + """Register a gym agent. + + Parameters + ---------- + agent: class + The agent class. + """ + + agents = MSCRegistery.get(MSCRegistery.GYM_AGENTS, {}) + assert hasattr(agent, "agent_type"), "agent_type should be given to register agent" + agents[agent.agent_type()] = agent + MSCRegistery.register(MSCRegistery.GYM_AGENTS, agents) + + +def get_registered_gym_agent(agent_type: str) -> Any: + """Get the registered agent. + + Parameters + ---------- + agent_type: string + The type of agent. 
+ + Returns + ------- + agent: class + The agent class. + """ + + agents = MSCRegistery.get(MSCRegistery.GYM_AGENTS, {}) + return agents.get(agent_type) + + +def register_gym_env(env: Any): + """Register a gym env. + + Parameters + ---------- + env: class + The env class. + """ + + envs = MSCRegistery.get(MSCRegistery.GYM_ENVS, {}) + assert hasattr(env, "env_type"), "env_type should be given to register env" + envs[env.env_type()] = env + MSCRegistery.register(MSCRegistery.GYM_ENVS, envs) + + +def get_registered_gym_env(env_type: str) -> Any: + """Get the registered env. + + Parameters + ---------- + env_type: string + The type of env. + + Returns + ------- + env: class + The env class. + """ + + envs = MSCRegistery.get(MSCRegistery.GYM_ENVS, {}) + return envs.get(env_type) + + +def register_gym_method(method: Any): + """Register a gym method. + + Parameters + ---------- + method: class + The method class. + """ + + methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) + assert hasattr(method, "method_type"), "method_type should be given to register method" + methods[method.method_type()] = method + MSCRegistery.register(MSCRegistery.GYM_METHODS, methods) + + +def get_registered_gym_method(method_type: str) -> Any: + """Get the registered method. + + Parameters + ---------- + method_type: str + The type of method. + + Returns + ------- + method: class + The method class. 
+ """ + + methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) + return methods.get(method_type) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py index 318334d33302..585e1dc82584 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py @@ -114,15 +114,16 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: folder.copy(path, info["src"]) return folder.move(engine_name + ".trt", output_folder.relpath(engine_name + ".trt")) - codegen = CodeGen( - graph, - _ffi_api.GetTensorRTSources, - codegen_config, - print_config, - build_folder.create_dir(graph.name), - code_format="cpp", - ) - engine_file = codegen.load([], pre_load=_create_depends, post_load=_build_engine) + with build_folder as folder: + codegen = CodeGen( + graph, + _ffi_api.GetTensorRTSources, + codegen_config, + print_config, + folder.create_dir(graph.name), + code_format="cpp", + ) + engine_file = codegen.load([], pre_load=_create_depends, post_load=_build_engine) return { "graph_json": graph.to_json(), "graph_name": graph.name, diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index d19c7995ed7d..9a45ee84b394 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -19,6 +19,7 @@ import os import time +import json from typing import Dict, Any import traceback import numpy as np @@ -29,6 +30,7 @@ from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core.gym.control import create_controller class BaseManager(object): @@ -561,7 +563,25 @@ def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) tool_stage, t_stage_config, 
tools_config=tools_config, profile=False, use_cache=False ) if gym_configs: - raise NotImplementedError("Gym is not implemented") + knowledge = None + for idx, config in enumerate(gym_configs): + self._logger.info("GYM[%d/%d].CREATE(%s)", idx, len(gym_configs), tool_stage) + extra_config = { + "env": { + "runner": runner, + "data_loader": self._data_loader, + "knowledge": knowledge, + }, + "debug_level": runner.debug_level, + } + controller = create_controller(runner.stage, config, extra_config) + knowledge = controller.run() + with open(plan_file, "w") as f: + f.write(json.dumps(knowledge, indent=2)) + self._logger.info( + "Gym save %d knowledge(%s) -> %s", len(knowledge), tool_type, plan_file + ) + return plan_file return runner.apply_tool(tool_type, self._data_loader) def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index 5723e3d9ffc7..61d1001ea849 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -17,7 +17,7 @@ """A PyTorch-like API to build IRModules.""" # pylint: disable=redefined-builtin from . import op, spec -from .core import Effect, Module, ModuleList, Parameter, Tensor +from .core import Effect, Module, ModuleList, Object, Parameter, Tensor from .exporter import add_extern from .extern import ExternModule, ObjectModule, SourceModule from .modules import ( diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 8ed0efe2cd04..9c99ba6177d2 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -50,7 +50,12 @@ from ... 
import expr as rx from ...block_builder import BlockBuilder -from ...struct_info import ShapeStructInfo, TensorStructInfo, TupleStructInfo +from ...struct_info import ( + ObjectStructInfo, + ShapeStructInfo, + TensorStructInfo, + TupleStructInfo, +) from ._tensor_op import _TensorOp from .subroutine import SubroutineMixin @@ -274,6 +279,22 @@ def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-na )._expr +class Object: + """A wrapper on top of relax.Expr whose struct_info is the base + ObjectStructInfo (rather than any its subclass). Object effectively + represents non-tensor frontend components such as KV caches. + """ + + _expr: rx.Var + + def __init__(self, *, _expr: rx.Expr, _name: str) -> None: + """Private constructor. Object is never supposed to be constructed directly by users.""" + if not isinstance(_expr, rx.Var): + _expr = BlockBuilder.current().emit(_expr, _name) + self._expr = _expr + assert isinstance(self._expr.struct_info, ObjectStructInfo) + + class Effect: """Effect is a special non-user facing type that is used to represent operations with side effects, for example, print. It is used to represent the output of a computation. diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 416913def48b..99591c8a3e2e 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -23,7 +23,7 @@ from ... import expr as rx from ...block_builder import BlockBuilder -from ...struct_info import ShapeStructInfo, TupleStructInfo +from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo from . import core, extern from . 
import spec as _spec from .modules import IOEffect @@ -160,7 +160,7 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- ): # pylint: disable=protected-access def _unwrap_ret(expr: typing.Any) -> typing.Any: - if isinstance(expr, core.Tensor): + if isinstance(expr, (core.Tensor, core.Object)): return expr._expr if isinstance(expr, tuple): return rx.Tuple([_unwrap_ret(x) for x in expr]) @@ -171,7 +171,7 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any: def _convert_input(arg): if isinstance(arg, tir.Var): return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) - if isinstance(arg, core.Tensor): + if isinstance(arg, (core.Tensor, core.Object)): return arg._expr # pylint: disable=protected-access if isinstance(arg, _spec.Tuple): return rx.Var( @@ -292,6 +292,8 @@ def _convert_input(arg_name, arg_spec): dtype=arg_spec.dtype, name=arg_name, ) + elif isinstance(arg_spec, _spec.Object): + arg = arg_spec.object_type(_expr=rx.Var(arg_name, ObjectStructInfo()), _name=arg_name) elif isinstance(arg_spec, _spec.Tuple): elements = type(arg_spec.elements)( [ diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 3197145289ef..66f023ef9ded 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1142,6 +1142,65 @@ def zeros( return wrap_nested(_op.zeros(shape, dtype), name) +def ones( + shape: Sequence[IntExpr], + dtype: str = "float32", + name: str = "ones", +) -> Tensor: + """Construct a tensor of all zeros, with the input shape and dtype. + + Parameters + ---------- + shape : Sequence[IntExpr] + The shape of the created tensor. + + dtype : str + The data type of the created tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The result tensor. 
+ """ + return wrap_nested(_op.ones(shape, dtype), name) + + +def empty( + shape: Sequence[IntExpr], + dtype: str = "float32", + name: str = "empty", +) -> Tensor: + """Construct an uninitialized tensor, with the input shape and dtype. + + Parameters + ---------- + shape : Sequence[IntExpr] + The shape of the created tensor. + + dtype : str + The data type of the created tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The result tensor. + """ + return wrap_nested( # type: ignore + _op.builtin.alloc_tensor( + rx.ShapeExpr(shape), # type: ignore + dtype, + runtime_device_index=0, + ), + name, + ) + + def split( ary: Tensor, indices_or_sections: Union[int, Sequence[int]], diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py index 210b16ce013a..54928ce07b80 100644 --- a/python/tvm/relax/frontend/nn/spec.py +++ b/python/tvm/relax/frontend/nn/spec.py @@ -24,7 +24,7 @@ ArgSpecType = typing.Union["Int", "Tensor"] MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]] ModuleSpecType = typing.Union["ModuleSpec", typing.Dict[str, MethodSpecType]] -SpecAny = typing.Union["Int", "Tensor", "Tuple"] +SpecAny = typing.Union["Object", "Int", "Tensor", "Tuple"] class Int: # pylint: disable=too-few-public-methods @@ -52,6 +52,18 @@ def __repr__(self) -> str: return f"Tensor([{shape}], '{self.dtype}')" +class Object: # pylint: disable=too-few-public-methods + """An non-tensor opaque frontend object.""" + + object_type: typing.Type + + def __init__(self, object_type: typing.Type) -> None: + self.object_type = object_type + + def __repr__(self) -> str: + return "object" + + class Tuple: # pylint: disable=too-few-public-methods """A tuple input or a list input""" @@ -141,7 +153,7 @@ def _convert_arg_spec(arg_spec, arg_name): return Int() if isinstance(arg_spec, str) and arg_spec == "int": return Int() - if isinstance(arg_spec, (Int, Tensor)): + if isinstance(arg_spec, (Int, Tensor, Object)): return 
arg_spec if isinstance(arg_spec, (tuple, list, Tuple)): return Tuple( diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 4a2a1555ff46..3873f624efcf 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -219,6 +219,12 @@ class TokenAllocator1D { available_pool_[token->dtype].insert({token->bytes, token}); } + /*! \brief Clear the allocator. */ + void Clear() { + available_pool_.clear(); + full_pool_.clear(); + } + private: /*! \brief A constant scale representing the token search range. */ const int match_range_{16}; @@ -569,6 +575,8 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { if (func == nullptr) { continue; } + // Clear the allocator to make the planning of different functions independent. + allocator_.Clear(); this->VisitExpr_(func); } } diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index fc7d351e5b77..e941908dbc2a 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -848,9 +848,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { if (num_depths_ == 1) { if (use_decode_kernel_[0]) { f_attention_decode_begin_forward_( - /*depth=*/0, page_indptr_on_depths_view_[0], page_indices_on_depths_view_[0], - last_page_len_on_depths_view_[0], /*return_lse=*/true, num_qo_heads_, num_kv_heads_, - head_dim_, page_size_, /*rotary_mode=*/true); + /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0], + num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true); } else { f_attention_prefill_begin_forward_(/*depth=*/0, qo_indptr_on_depths_view_[0], cur_batch_size_, num_qo_heads_, num_kv_heads_); @@ -864,9 +863,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { } if (use_decode_kernel_[d]) { f_attention_decode_begin_forward_( - d, page_indptr_on_depths_view_[d], 
page_indices_on_depths_view_[d], - last_page_len_on_depths_view_[d], /*rotary_mode=*/false, num_qo_heads_, num_kv_heads_, - head_dim_, page_size_, /*return_lse=*/true); + d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_, + num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false); } else { f_attention_prefill_begin_forward_(/*depth=*/d, qo_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d]->shape[0], diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_manager.py index 393c8decbc79..c3e1583ef291 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_manager.py @@ -34,6 +34,7 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-2, rtol=1e-2): """Get msc config""" return { + "workspace": msc_utils.msc_dir(), "model_type": model_type, "inputs": inputs, "outputs": outputs, diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index a6005f5d41c2..cbc33d452846 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -82,7 +82,7 @@ def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3, rtol=1e-3 torch_model = _get_torch_model("resnet50", is_training) if torch_model: - workspace = msc_utils.set_workspace() + workspace = msc_utils.set_workspace(msc_utils.msc_dir()) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("info", log_path) input_info = [([1, 3, 224, 224], "float32")] @@ -139,7 +139,7 @@ def test_tensorflow_runner(): tf_graph, graph_def = _get_tf_graph() if tf_graph and graph_def: - workspace = msc_utils.set_workspace() + workspace = msc_utils.set_workspace(msc_utils.msc_dir()) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("info", log_path) data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") diff --git 
a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index cda6e231e827..6adf70605bfd 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -46,10 +46,10 @@ def _get_config( ): """Get msc config""" return { + "workspace": msc_utils.msc_dir(), "model_type": model_type, "inputs": inputs, "outputs": outputs, - "debug_level": 0, "dataset": {"loader": "from_random", "max_iter": 5}, "prepare": {"profile": {"benchmark": {"repeat": 10}}}, "baseline": { @@ -68,7 +68,7 @@ def _get_config( } -def get_tool_config(tool_type, use_distill=False): +def get_tool_config(tool_type, use_distill=False, use_gym=False): """Get config for the tool""" config = {} if tool_type == ToolType.PRUNER: @@ -76,6 +76,23 @@ def get_tool_config(tool_type, use_distill=False): "plan_file": "msc_pruner.json", "strategys": [{"method": "per_channel", "density": 0.8}], } + if use_gym: + config["gym_configs"] = [ + { + "env": { + "executors": { + "action_space": { + "method": "action_prune_density", + "start": 0.4, + "end": 0.8, + "step": 0.4, + } + }, + "max_tasks": 3, + }, + "agent": {"agent_type": "search.grid", "executors": {}}, + } + ] elif tool_type == ToolType.QUANTIZER: # pylint: disable=import-outside-toplevel from tvm.contrib.msc.core.tools.quantize import QuantizeStage @@ -113,6 +130,23 @@ def get_tool_config(tool_type, use_distill=False): }, ], } + if use_gym: + config["gym_configs"] = [ + { + "env": { + "executors": { + "action_space": { + "method": "action_quantize_scale", + "start": 0.8, + "end": 1.2, + "step": 0.2, + } + }, + "max_tasks": 3, + }, + "agent": {"agent_type": "search.grid", "executors": {}}, + } + ] elif tool_type == ToolType.TRACKER: config = { "plan_file": "msc_tracker.json", @@ -184,16 +218,16 @@ def _test_from_torch( ) manager = MSCManager(torch_model, config) report = manager.run_pipe() - assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) + 
model_info = manager.runner.model_info for t_type, config in tools_config.items(): assert os.path.isfile( msc_utils.get_config_dir().relpath(config["plan_file"]) ), "Failed to find plan of " + str(t_type) - model_info = manager.runner.model_info + manager.destory() + assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) assert msc_utils.dict_equal( model_info, expected_info ), "Model info {} mismatch with expected {}".format(model_info, expected_info) - manager.destory() def get_model_info(compile_type): @@ -249,6 +283,16 @@ def test_tvm_distill(tool_type): ) +@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) +def test_tvm_gym(tool_type): + """Test tools for tvm with distiller""" + + tool_config = get_tool_config(tool_type, use_gym=True) + _test_from_torch( + MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), is_training=True + ) + + @requires_tensorrt @pytest.mark.parametrize( "tool_type", @@ -285,5 +329,16 @@ def test_tensorrt_distill(tool_type): ) +@requires_tensorrt +@pytest.mark.parametrize("tool_type", [ToolType.PRUNER]) +def test_tensorrt_gym(tool_type): + """Test tools for tensorrt with gym""" + + tool_config = get_tool_config(tool_type, use_gym=True) + _test_from_torch( + MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), is_training=False + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 55870426e485..43f4a9efc03f 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -17,12 +17,14 @@ # pylint: disable=missing-docstring, invalid-name import tvm import tvm.testing -from tvm import tir +from tvm import relax, tir from tvm.relax.frontend.nn import Module, Tensor, op, spec from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T +# mypy: 
disable-error-code="attr-defined,valid-type,name-defined" + def test_binary(): class Model(Module): @@ -174,7 +176,7 @@ class Model(Module): def test(self, x: Tensor, weight: Tensor, bias: Tensor): padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1]) conv2d = op.conv2d(padded, weight, bias) - interpolate = op.interpolate(x, size=[40, 40]) + interpolate = op.interpolate(x, size=[40, 40]) # type: ignore return (conv2d, interpolate) @R.function @@ -347,7 +349,7 @@ def test_create(): class Model(Module): def test(self, x: Tensor): triu_out = op.triu(x) - full_with_scalar_out = op.full([10, 10], fill_value=10) + full_with_scalar_out = op.full([10, 10], fill_value=10) # type: ignore full_with_FloatImm_out = op.full( [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10) ) @@ -638,5 +640,24 @@ def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), d tvm.ir.assert_structural_equal(irmodule, Expected) +def test_empty(): + @tvm.register_func("test_empty_assert", override=True) + def test_empty_assert(_lineo, x): + assert x.shape == (10, 10) + assert x.dtype == "float32" + + class Model(Module): + def test(self): + result = op.empty([10, 10], dtype="float32") + op.debug_func("test_empty_assert", result) + return result + + irmodule, _ = Model().export_tvm(spec={"test": {}}, debug=True) + ex = relax.build(irmodule, "llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + effects = vm["_initialize_effect"]() + vm["test"](*effects) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 0c24f90efc11..f12b5b9fc142 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -18,7 +18,9 @@ import tvm import tvm.testing from tvm import relax -from tvm.script import ir as I, relax as R, tir as T +from tvm.script import ir as I +from tvm.script 
import relax as R +from tvm.script import tir as T def test_basic(): @@ -1105,5 +1107,74 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"): tvm.ir.assert_structural_equal(mod, Expected) +def test_function_independence(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def func1(x: R.Tensor((8,), dtype="float32")) -> R.Tensor((8,), dtype="float32"): + R.func_attr({"relax.force_pure": 1}) + cls = Module + alloc: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8,]), dtype="float32", runtime_device_index=0) + _: R.Tuple() = cls.exp(x, alloc) + lv: R.Tensor((8,), dtype="float32") = alloc + alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8,]), dtype="float32", runtime_device_index=0) + _1: R.Tuple() = cls.exp(lv, alloc1) + gv: R.Tensor((8,), dtype="float32") = alloc1 + return gv + + @R.function + def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"relax.force_pure": 1}) + cls = Module + alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10,]), dtype="float32", runtime_device_index=0) + _: R.Tuple() = cls.exp(x, alloc) + lv: R.Tensor((10,), dtype="float32") = alloc + alloc1: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10,]), dtype="float32", runtime_device_index=0) + _1: R.Tuple() = cls.exp(lv, alloc1) + gv: R.Tensor((10,), dtype="float32") = alloc1 + return gv + + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def func1(x: R.Tensor((8,), dtype="float32")) -> R.Tensor((8,), dtype="float32"): + R.func_attr({"relax.force_pure": 1}) + cls = Expected + storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) + alloc: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([8]), 
R.dtype("float32")) + _: R.Tuple = cls.exp(x, alloc) + lv: R.Tensor((8,), dtype="float32") = alloc + alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), R.dtype("float32"), R.prim_value(0)) + _1: R.Tuple = cls.exp(lv, alloc1) + gv: R.Tensor((8,), dtype="float32") = alloc1 + return gv + + @R.function + def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"relax.force_pure": 1}) + cls = Expected + storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32")) + alloc: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([10]), R.dtype("float32")) + _: R.Tuple = cls.exp(x, alloc) + lv: R.Tensor((10,), dtype="float32") = alloc + alloc1: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0)) + _1: R.Tuple = cls.exp(lv, alloc1) + gv: R.Tensor((10,), dtype="float32") = alloc1 + return gv + # fmt: on + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 60f40adbf40c..311bbd9971ac 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -126,8 +126,8 @@ TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet *ret = (obj.use_count() - 1); }); -void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format) { - if (format == "f32-to-bf16") { +void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { + if (format == "f32-to-bf16" && dtype == "float32") { std::vector buffer(bytes.length() / 2); std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2); // decode bf16 to f32 diff --git a/web/src/runtime.ts b/web/src/runtime.ts index f842b2723f81..5aa38dee39b1 100644 --- a/web/src/runtime.ts +++ 
b/web/src/runtime.ts @@ -1556,7 +1556,7 @@ export class Instance implements Disposable { }); const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); // first sync copy to cpu. - this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format); + this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); // then async stream into GPU if needed if (device.deviceType === DeviceStrToEnum.cpu) { this.ndarrayCacheUpdate(rec.name, cpu_arr, false);