Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MSC][M5.2] Enable quantize && prune with gym by wrapper #16702

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions gallery/how_to/work_with_msc/using_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
parser.add_argument("--calibrate_iter", type=int, default=100, help="The iter for calibration")
parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train")
parser.add_argument("--train_iter", type=int, default=200, help="The iter for train")
parser.add_argument("--train_epoch", type=int, default=5, help="The epoch for train")
parser.add_argument("--train_epoch", type=int, default=100, help="The epoch for train")
parser.add_argument(
"--verbose", type=str, default="info", help="The verbose level, info|debug:1,2,3|critical"
)
args = parser.parse_args()


Expand Down Expand Up @@ -86,7 +89,7 @@ def get_config(calib_loader, train_loader):
dataset=dataset,
tools=tools,
skip_config={"all": "check"},
verbose="info",
verbose=args.verbose,
)


Expand Down Expand Up @@ -130,3 +133,7 @@ def _get_train_datas():
model.compile()
acc = eval_model(model, testloader, max_iter=args.test_iter)
print("Compiled acc: " + str(acc))

# export the model
path = model.export()
print("Export model to " + str(path))
49 changes: 30 additions & 19 deletions python/tvm/contrib/msc/core/gym/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
import logging
from typing import Dict, Any, List, Tuple
from tvm.contrib.msc.core.gym.namespace import GYMObject
from tvm.contrib.msc.core import utils as msc_utils


Expand All @@ -37,8 +38,6 @@ class BaseAgent(object):
The extra options for the agent.
debug_level: int
The debug level.
verbose: str
The verbose level.
logger: logging.Logger
The logger
"""
Expand All @@ -50,23 +49,15 @@ def __init__(
executors: dict,
options: dict = None,
debug_level: int = 0,
verbose: str = None,
logger: logging.Logger = None,
):
self._name = name
self._workspace = workspace
self._executors = self._parse_executors(msc_utils.copy_dict(executors))
self._options = options or {}
self._debug_level = debug_level
if logger:
self._logger = logger
else:
if not verbose:
verbose = "debug" if debug_level > 0 else "info"
self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("AGENT_LOG"))
self._logger.info(
msc_utils.msg_block("AGENT.SETUP({})".format(self.agent_type()), self.setup())
)
self._logger = logger or msc_utils.get_global_logger()
self._logger.info(msc_utils.msg_block(self.agent_mark("SETUP"), self.setup()))

def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]:
"""Parse the executors
Expand All @@ -85,9 +76,12 @@ def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, di
executors = {}
for name, raw_config in executors_dict.items():
method_type = (
raw_config.pop("method_type") if "method_type" in raw_config else "agent.default"
raw_config.pop("method_type") if "method_type" in raw_config else "default"
)
method_cls = msc_utils.get_registered_gym_method(GYMObject.AGENT, method_type)
assert method_cls, "Can not find method cls for {}:{}".format(
GYMObject.AGENT, method_type
)
method_cls = msc_utils.get_registered_gym_method(method_type)
assert "method" in raw_config, "method should be given to find agent method"
method_name, method = raw_config.pop("method"), None
if hasattr(method_cls, method_name):
Expand Down Expand Up @@ -244,7 +238,7 @@ def learn(self):
The learned rewards.
"""

self._logger.debug(msc_utils.msg_block("AGENT.LEARN", self._knowledge))
self._logger.debug(msc_utils.msg_block(self.agent_mark("KNOWLEDEG"), self._knowledge))
return self._learn()

def _learn(self):
Expand Down Expand Up @@ -306,9 +300,26 @@ def _evaluate(self, reward: dict) -> float:

return self._execute("evaluate", self._baseline, reward)

@classmethod
def agent_type(cls):
return "base"
def agent_mark(self, msg: Any) -> str:
"""Mark the message with agent info

Parameters
-------
msg: str
The message

Returns
-------
msg: str
The message with mark.
"""

return "AGENT({}) {}".format(self.role_type(), msg)

msc_utils.register_gym_agent(BaseAgent)
@classmethod
def role(cls):
return GYMObject.AGENT

@classmethod
def role_type(cls):
return "base"
11 changes: 7 additions & 4 deletions python/tvm/contrib/msc/core/gym/agent/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
"""tvm.contrib.msc.core.gym.agent.method"""

from typing import Any
from tvm.contrib.msc.core.gym.namespace import GYMObject
from tvm.contrib.msc.core import utils as msc_utils


@msc_utils.register_gym_method
class AgentMethod(object):
"""Default prune method"""

Expand Down Expand Up @@ -73,8 +75,9 @@ def evaluate_by_thresh(cls, agent: Any, baseline: dict, reward: dict, thresh: fl
return reward["reward"]

@classmethod
def method_type(cls):
return "agent.default"

def role(cls):
return GYMObject.AGENT

msc_utils.register_gym_method(AgentMethod)
@classmethod
def method_type(cls):
return "default"
11 changes: 5 additions & 6 deletions python/tvm/contrib/msc/core/gym/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def setup(self) -> dict:
return super().setup()

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.base"


@msc_utils.register_gym_object
class GridSearchAgent(BaseSearchAgent):
"""GridSearch agent"""

Expand Down Expand Up @@ -92,10 +93,11 @@ def _learn(self):
return best_actions, best_rewards

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.grid"


@msc_utils.register_gym_object
class BinarySearchAgent(BaseSearchAgent):
"""BinarySearch agent"""

Expand Down Expand Up @@ -173,8 +175,5 @@ def _learn(self):
return actions, rewards

@classmethod
def agent_type(cls):
def role_type(cls):
return "search.binary"


msc_utils.register_gym_agent(GridSearchAgent)
12 changes: 5 additions & 7 deletions python/tvm/contrib/msc/core/gym/control/configer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def update(self, raw_config: dict) -> dict:
raise NotImplementedError("update is not implemented in BaseConfiger")


@msc_utils.register_gym_configer
class DefaultConfiger(BaseConfiger):
"""Default configer for gym"""

Expand All @@ -67,10 +68,10 @@ def update(self, raw_config: dict) -> dict:

config = msc_utils.copy_dict(raw_config)
assert "env" in config and "agent" in config, "env and agent should be given to run gym"
if "env_type" not in config["env"]:
config["env"]["env_type"] = self._stage + ".default"
if "agent_type" not in config["agent"]:
config["agent"]["agent_type"] = "search.grid"
if "role_type" not in config["env"]:
config["env"]["role_type"] = self._stage + ".default"
if "role_type" not in config["agent"]:
config["agent"]["role_type"] = "search.grid"
if "executors" not in config["env"]:
config["env"]["executors"] = {}
# update executors
Expand All @@ -92,6 +93,3 @@ def update(self, raw_config: dict) -> dict:
@classmethod
def config_type(cls):
return "default"


msc_utils.register_gym_configer(DefaultConfiger)
6 changes: 2 additions & 4 deletions python/tvm/contrib/msc/core/gym/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"""tvm.contrib.msc.core.gym.control.controller"""

from typing import Dict, Any
from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction
from tvm.contrib.msc.core import utils as msc_utils
from .service import MainService, NodeService
from .namespace import GYMObject, GYMAction


class BaseController(object):
Expand Down Expand Up @@ -98,10 +98,8 @@ def create_controller(stage: str, config: dict, extra_config: dict = None):
return controller_cls(msc_utils.get_gym_dir(), config)


@msc_utils.register_gym_controller
class DefaultController(BaseController):
@classmethod
def control_type(cls):
return "default"


msc_utils.register_gym_controller(DefaultController)
46 changes: 28 additions & 18 deletions python/tvm/contrib/msc/core/gym/control/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import queue
import numpy as np

from tvm.contrib.msc.core.gym.namespace import GYMObject, GYMAction
from tvm.contrib.msc.core import utils as msc_utils
from .worker import BaseWorker, WorkerFactory
from .namespace import GYMObject, GYMAction
from .worker import BaseGymWorker, WorkerFactory


def _send_message(msg_queue: queue.Queue, header: str, body: dict, header_type: str = "message"):
Expand Down Expand Up @@ -149,10 +149,8 @@ class BaseService(object):
The max seatch iter.
record_step: int
The record step.
debug_level: int
The debug level
verbose: str
The verbose level.
The verbose level
"""

def __init__(
Expand All @@ -170,15 +168,13 @@ def __init__(
):
self._workspace = workspace
tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"]
if not verbose:
verbose = "debug" if debug_level > 0 else "info"
verbose = verbose or "info"
debug_level = int(verbose.split(":")[1]) if verbose.startswith("debug:") else 0
self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG"))

def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]:
def _create_workers(config: dict, obj_type: str) -> List[BaseGymWorker]:
if "debug_level" not in config:
config["debug_level"] = debug_level
if "verbose" not in config:
config["verbose"] = verbose
if "logger" not in config:
config["logger"] = self._logger
return [
Expand All @@ -192,9 +188,7 @@ def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]:
self._max_iter = max_iter
self._record_step = record_step
self._debug_level = debug_level
self._logger.info(
msc_utils.msg_block("SERVICE.SETUP({})".format(self.service_type), self.setup())
)
self._logger.info(msc_utils.msg_block(self.service_mark("SETUP"), self.setup()))

def setup(self) -> dict:
"""Setup the tool
Expand Down Expand Up @@ -242,8 +236,8 @@ def reset(self):
self._task_id, self._states = 0, []
self._iter_done = False
self._logger.info("SERVICE Reset %d/%d th iter", self._iter_id, self._max_iter)
self.execute(GYMObject.AGENT, GYMAction.RESET)
self.execute(GYMObject.ENV, GYMAction.RESET)
self.execute(GYMObject.AGENT, GYMAction.RESET)

def learn(self):
self.execute(GYMObject.AGENT, GYMAction.LEARN)
Expand Down Expand Up @@ -387,9 +381,9 @@ def _process_request(self, msg_key: str) -> dict:
workers = {w.worker_id: w for w in self._get_workers(obj_type)}
requests = self._wait_request(msg_key)
if act_type in (GYMAction.INIT, GYMAction.RESET):
mark = "I[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type)
mark = "Iter[{}/{}] {}.{}".format(self._iter_id, self._max_iter, obj_type, act_type)
else:
mark = "I[{}/{}].T[{}/{}] {}.{}".format(
mark = "Iter[{}/{}] Task[{}/{}] {}.{}".format(
self._iter_id, self._max_iter, self._task_id, self._max_task, obj_type, act_type
)
requests = {int(k): v for k, v in requests.items()}
Expand All @@ -400,7 +394,7 @@ def _process_request(self, msg_key: str) -> dict:
"requests": {workers[w].name: r for w, r in requests.items()},
"responses": {workers[w].name: r for w, r in responses.items()},
}
self._logger.info(msc_utils.msg_table(mark, info))
self._logger.info(msc_utils.msg_block(mark, info, symbol="="))
return responses

def _process_response(self, msg_key: str, response: dict):
Expand Down Expand Up @@ -464,7 +458,7 @@ def _from_msg_key(self, msg_key: str) -> Tuple[str, str]:

return msg_key.split("-s-")

def _get_workers(self, obj_type: str) -> List[BaseWorker]:
def _get_workers(self, obj_type: str) -> List[BaseGymWorker]:
"""Get workers according to obj_type

Parameters
Expand Down Expand Up @@ -519,6 +513,22 @@ def _get_world_ids(self, obj_type: str) -> List[int]:
return self._agent_world_ids
return []

def service_mark(self, msg: Any) -> str:
"""Mark the message with service info

Parameters
-------
msg: str
The message

Returns
-------
msg: str
The message with mark.
"""

return "SERIVCE({}) {}".format(self.service_type, msg)

@property
def done(self):
return self._done
Expand Down
Loading
Loading