Skip to content

Commit

Permalink
[MSC][M5.2] Enable quantize && prune with gym by wrapper (#16702)
Browse files Browse the repository at this point in the history
change register
  • Loading branch information
Archermmt authored Mar 13, 2024
1 parent 831d769 commit 3b25697
Show file tree
Hide file tree
Showing 56 changed files with 727 additions and 420 deletions.
11 changes: 9 additions & 2 deletions gallery/how_to/work_with_msc/using_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
parser.add_argument("--calibrate_iter", type=int, default=100, help="The iter for calibration")
parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train")
parser.add_argument("--train_iter", type=int, default=200, help="The iter for train")
parser.add_argument("--train_epoch", type=int, default=5, help="The epoch for train")
parser.add_argument("--train_epoch", type=int, default=100, help="The epoch for train")
parser.add_argument(
"--verbose", type=str, default="info", help="The verbose level, info|debug:1,2,3|critical"
)
args = parser.parse_args()


Expand Down Expand Up @@ -86,7 +89,7 @@ def get_config(calib_loader, train_loader):
dataset=dataset,
tools=tools,
skip_config={"all": "check"},
verbose="info",
verbose=args.verbose,
)


Expand Down Expand Up @@ -130,3 +133,7 @@ def _get_train_datas():
model.compile()
acc = eval_model(model, testloader, max_iter=args.test_iter)
print("Compiled acc: " + str(acc))

# export the model
path = model.export()
print("Export model to " + str(path))
49 changes: 30 additions & 19 deletions python/tvm/contrib/msc/core/gym/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
import logging
from typing import Dict, Any, List, Tuple
from tvm.contrib.msc.core.gym.namespace import GYMObject
from tvm.contrib.msc.core import utils as msc_utils


Expand All @@ -37,8 +38,6 @@ class BaseAgent(object):
The extra options for the agent.
debug_level: int
The debug level.
verbose: str
The verbose level.
logger: logging.Logger
The logger
"""
Expand All @@ -50,23 +49,15 @@ def __init__(
executors: dict,
options: dict = None,
debug_level: int = 0,
verbose: str = None,
logger: logging.Logger = None,
):
self._name = name
self._workspace = workspace
self._executors = self._parse_executors(msc_utils.copy_dict(executors))
self._options = options or {}
self._debug_level = debug_level
if logger:
self._logger = logger
else:
if not verbose:
verbose = "debug" if debug_level > 0 else "info"
self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("AGENT_LOG"))
self._logger.info(
msc_utils.msg_block("AGENT.SETUP({})".format(self.agent_type()), self.setup())
)
self._logger = logger or msc_utils.get_global_logger()
self._logger.info(msc_utils.msg_block(self.agent_mark("SETUP"), self.setup()))

def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]:
"""Parse the executors
Expand All @@ -85,9 +76,12 @@ def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, di
executors = {}
for name, raw_config in executors_dict.items():
method_type = (
raw_config.pop("method_type") if "method_type" in raw_config else "agent.default"
raw_config.pop("method_type") if "method_type" in raw_config else "default"
)
method_cls = msc_utils.get_registered_gym_method(GYMObject.AGENT, method_type)
assert method_cls, "Can not find method cls for {}:{}".format(
GYMObject.AGENT, method_type
)
method_cls = msc_utils.get_registered_gym_method(method_type)
assert "method" in raw_config, "method should be given to find agent method"
method_name, method = raw_config.pop("method"), None
if hasattr(method_cls, method_name):
Expand Down Expand Up @@ -244,7 +238,7 @@ def learn(self):
The learned rewards.
"""

self._logger.debug(msc_utils.msg_block("AGENT.LEARN", self._knowledge))
self._logger.debug(msc_utils.msg_block(self.agent_mark("KNOWLEDEG"), self._knowledge))
return self._learn()

def _learn(self):
Expand Down Expand Up @@ -306,9 +300,26 @@ def _evaluate(self, reward: dict) -> float:

return self._execute("evaluate", self._baseline, reward)

@classmethod
def agent_type(cls):
return "base"
def agent_mark(self, msg: Any) -> str:
"""Mark the message with agent info
Parameters
-------
msg: str
The message
Returns
-------
msg: str
The message with mark.
"""

return "AGENT({}) {}".format(self.role_type(), msg)

msc_utils.register_gym_agent(BaseAgent)
@classmethod
def role(cls):
return GYMObject.AGENT

@classmethod
def role_type(cls):
return "base"
11 changes: 7 additions & 4 deletions python/tvm/contrib/msc/core/gym/agent/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
"""tvm.contrib.msc.core.gym.agent.method"""

from typing import Any
from tvm.contrib.msc.core.gym.namespace import GYMObject
from tvm.contrib.msc.core import utils as msc_utils


@msc_utils.register_gym_method
class AgentMethod(object):
"""Default prune method"""

Expand Down Expand Up @@ -73,8 +75,9 @@ def evaluate_by_thresh(cls, agent: Any, baseline: dict, reward: dict, thresh: fl
return reward["reward"]

@classmethod
def method_type(cls):
return "agent.default"

def role(cls):
return GYMObject.AGENT

msc_utils.register_gym_method(AgentMethod)
@classmethod
def method_type(cls):
return "default"
11 changes: 5 additions & 6 deletions python/tvm/contrib/msc/core/gym/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def setup(self) -> dict:
return super().setup()

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.base"


@msc_utils.register_gym_object
class GridSearchAgent(BaseSearchAgent):
"""GridSearch agent"""

Expand Down Expand Up @@ -92,10 +93,11 @@ def _learn(self):
return best_actions, best_rewards

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.grid"


@msc_utils.register_gym_object
class BinarySearchAgent(BaseSearchAgent):
"""BinarySearch agent"""

Expand Down Expand Up @@ -173,8 +175,5 @@ def _learn(self):
return actions, rewards

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.binary"


msc_utils.register_gym_agent(GridSearchAgent)
12 changes: 5 additions & 7 deletions python/tvm/contrib/msc/core/gym/control/configer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def update(self, raw_config: dict) -> dict:
raise NotImplementedError("update is not implemented in BaseConfiger")


@msc_utils.register_gym_configer
class DefaultConfiger(BaseConfiger):
"""Default configer for gym"""

Expand All @@ -67,10 +68,10 @@ def update(self, raw_config: dict) -> dict:

config = msc_utils.copy_dict(raw_config)
assert "env" in config and "agent" in config, "env and agent should be given to run gym"
if "env_type" not in config["env"]:
config["env"]["env_type"] = self._stage + ".default"
if "agent_type" not in config["agent"]:
config["agent"]["agent_type"] = "search.grid"
if "role_type" not in config["env"]:
config["env"]["role_type"] = self._stage + ".default"
if "role_type" not in config["agent"]:
config["agent"]["role_type"] = "search.grid"
if "executors" not in config["env"]:
config["env"]["executors"] = {}
# update executors
Expand All @@ -92,6 +93,3 @@ def update(self, raw_config: dict) -> dict:
@classmethod
def config_type(cls):
return "default"


msc_utils.register_gym_configer(DefaultConfiger)
6 changes: 2 additions & 4 deletions python/tvm/contrib/msc/core/gym/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"""tvm.contrib.msc.core.gym.control.controller"""

from typing import Dict, Any
from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction
from tvm.contrib.msc.core import utils as msc_utils
from .service import MainService, NodeService
from .namespace import GYMObject, GYMAction


class BaseController(object):
Expand Down Expand Up @@ -98,10 +98,8 @@ def create_controller(stage: str, config: dict, extra_config: dict = None):
return controller_cls(msc_utils.get_gym_dir(), config)


@msc_utils.register_gym_controller
class DefaultController(BaseController):
@classmethod
def control_type(cls):
return "default"


msc_utils.register_gym_controller(DefaultController)
46 changes: 28 additions & 18 deletions python/tvm/contrib/msc/core/gym/control/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import queue
import numpy as np

from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction
from tvm.contrib.msc.core import utils as msc_utils
from .worker import BaseWorker, WorkerFactory
from .namespace import GYMObject, GYMAction
from .worker import BaseGymWorker, WorkerFactory


def _send_message(msg_queue: queue.Queue, header: str, body: dict, header_type: str = "message"):
Expand Down Expand Up @@ -149,10 +149,8 @@ class BaseService(object):
The max seatch iter.
record_step: int
The record step.
debug_level: int
The debug level
verbose: str
The verbose level.
The verbose level
"""

def __init__(
Expand All @@ -170,15 +168,13 @@ def __init__(
):
self._workspace = workspace
tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"]
if not verbose:
verbose = "debug" if debug_level > 0 else "info"
verbose = verbose or "info"
debug_level = int(verbose.split(":")[1]) if verbose.startswith("debug:") else 0
self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG"))

def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]:
def _create_workers(config: dict, obj_type: str) -> List[BaseGymWorker]:
if "debug_level" not in config:
config["debug_level"] = debug_level
if "verbose" not in config:
config["verbose"] = verbose
if "logger" not in config:
config["logger"] = self._logger
return [
Expand All @@ -192,9 +188,7 @@ def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]:
self._max_iter = max_iter
self._record_step = record_step
self._debug_level = debug_level
self._logger.info(
msc_utils.msg_block("SERVICE.SETUP({})".format(self.service_type), self.setup())
)
self._logger.info(msc_utils.msg_block(self.service_mark("SETUP"), self.setup()))

def setup(self) -> dict:
"""Setup the tool
Expand Down Expand Up @@ -242,8 +236,8 @@ def reset(self):
self._task_id, self._states = 0, []
self._iter_done = False
self._logger.info("SERVICE Reset %d/%d th iter", self._iter_id, self._max_iter)
self.execute(GYMObject.AGENT, GYMAction.RESET)
self.execute(GYMObject.ENV, GYMAction.RESET)
self.execute(GYMObject.AGENT, GYMAction.RESET)

def learn(self):
self.execute(GYMObject.AGENT, GYMAction.LEARN)
Expand Down Expand Up @@ -387,9 +381,9 @@ def _process_request(self, msg_key: str) -> dict:
workers = {w.worker_id: w for w in self._get_workers(obj_type)}
requests = self._wait_request(msg_key)
if act_type in (GYMAction.INIT, GYMAction.RESET):
mark = "I[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type)
mark = "Iter[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type)
else:
mark = "I[{}/{}].T[{}/{}] {}.{}".format(
mark = "Iter[{}/{}] Task[{}/{}] {}.{}".format(
self._iter_id, self._max_iter, self._task_id, self._max_task, obj_type, act_type
)
requests = {int(k): v for k, v in requests.items()}
Expand All @@ -400,7 +394,7 @@ def _process_request(self, msg_key: str) -> dict:
"requests": {workers[w].name: r for w, r in requests.items()},
"responses": {workers[w].name: r for w, r in responses.items()},
}
self._logger.info(msc_utils.msg_table(mark, info))
self._logger.info(msc_utils.msg_block(mark, info, symbol="="))
return responses

def _process_response(self, msg_key: str, response: dict):
Expand Down Expand Up @@ -464,7 +458,7 @@ def _from_msg_key(self, msg_key: str) -> Tuple[str, str]:

return msg_key.split("-s-")

def _get_workers(self, obj_type: str) -> List[BaseWorker]:
def _get_workers(self, obj_type: str) -> List[BaseGymWorker]:
"""Get workers according to obj_type
Parameters
Expand Down Expand Up @@ -519,6 +513,22 @@ def _get_world_ids(self, obj_type: str) -> List[int]:
return self._agent_world_ids
return []

def service_mark(self, msg: Any) -> str:
"""Mark the message with service info
Parameters
-------
msg: str
The message
Returns
-------
msg: str
The message with mark.
"""

return "SERIVCE({}) {}".format(self.service_type, msg)

@property
def done(self):
return self._done
Expand Down
Loading

0 comments on commit 3b25697

Please sign in to comment.