[DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)
* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* Rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* Distributed MADDPG

* Add critic worker

* Minor

* Minor data-parallel-related fixes

* Refine code structure for trainers & add more docstrings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to the abstract trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
fix a minor bug in MADDPG

* RL v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complete docstrings for the task queue and trainer.

* Distributed training pipeline draft

* Added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* Fixed merge conflicts

* Unified sync and async interfaces

* Refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor MADDPG-related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
5 people authored Jan 5, 2022
1 parent 8e4ad49 commit f80dcc3
Showing 51 changed files with 2,054 additions and 1,675 deletions.
17 changes: 9 additions & 8 deletions maro/rl/workflows/grad_worker.py
@@ -9,26 +9,27 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    # TODO: WORKERID in docker compose script.
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    worker_id = f"GRAD_WORKER.{from_env('WORKERID')}"
-    num_hosts = from_env("NUMHOSTS") if from_env("POLICYMANAGERTYPE") == "distributed" else 0
+    # TODO: WORKER_ID in docker compose script.
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    worker_id = f"GRAD_WORKER.{from_env('WORKER_ID')}"
+    num_hosts = from_env("NUM_HOSTS") if from_env("POLICY_MANAGER_TYPE") == "distributed" else 0
     max_cached_policies = from_env("MAXCACHED", required=False, default=10)
 
-    group = from_env("POLICYGROUP", required=False, default="learn")
+    group = from_env("POLICY_GROUP", required=False, default="learn")
     policy_dict = {}
     active_policies = []
     if num_hosts == 0:
         # no remote nodes for policy hosts
         num_hosts = len(policy_func_dict)
 
+    logger = get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), worker_id)
+
     peers = {"policy_manager": 1, "policy_host": num_hosts, "task_queue": 1}
     proxy = Proxy(
-        group, "grad_worker", peers, component_name=worker_id,
-        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
+        group, "grad_worker", peers, component_name=worker_id, logger=logger,
+        redis_address=(from_env("REDIS_HOST"), from_env("REDIS_PORT")),
         max_peer_discovery_retries=50
     )
-    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id)
 
     for msg in proxy.receive():
         if msg.tag == MsgTag.EXIT:
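Note: all of these workflow entry scripts are configured through environment variables read by `from_env`. As a rough sketch only (not the actual MARO helper, whose behavior this merely imitates), a minimal `from_env` with the `required`/`default` semantics relied on above could look like:

import os
from typing import Any

def from_env(var_name: str, required: bool = True, default: Any = None) -> Any:
    # Read a config value from the environment, with best-effort numeric casting.
    if var_name not in os.environ:
        if required:
            raise KeyError(f"Missing environment variable: {var_name}")
        return default
    value = os.environ[var_name]
    try:
        return int(value)  # e.g. NUM_HOSTS="2" -> 2
    except ValueError:
        try:
            return float(value)
        except ValueError:
            return value  # plain string, e.g. REDIS_HOST="maro-redis"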
35 changes: 18 additions & 17 deletions maro/rl/workflows/policy_host.py
@@ -9,29 +9,30 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    host_id = f"POLICY_HOST.{from_env('HOSTID')}"
+    host_id = f"POLICY_HOST.{from_env('HOST_ID')}"
     peers = {"policy_manager": 1}
-    data_parallelism = from_env("DATAPARALLELISM", required=False, default=1)
+    data_parallelism = from_env("DATA_PARALLELISM", required=False, default=1)
     if data_parallelism > 1:
         peers["grad_worker"] = data_parallelism
         peers["task_queue"] = 1
 
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    group = from_env("POLICYGROUP")
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    group = from_env("POLICY_GROUP")
     policy_dict, checkpoint_path = {}, {}
 
+    logger = get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), host_id)
+
     proxy = Proxy(
         group, "policy_host", peers,
         component_name=host_id,
-        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
+        logger=logger,
+        redis_address=(from_env("REDIS_HOST"), from_env("REDIS_PORT")),
         max_peer_discovery_retries=50
     )
-    load_policy_dir = from_env("LOADDIR", required=False, default=None)
-    checkpoint_dir = from_env("CHECKPOINTDIR", required=False, default=None)
-    if checkpoint_dir:
-        os.makedirs(checkpoint_dir, exist_ok=True)
-
-    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), host_id)
+    load_path = from_env("LOAD_PATH", required=False, default=None)
+    checkpoint_path = from_env("CHECKPOINT_PATH", required=False, default=None)
+    if checkpoint_path:
+        os.makedirs(checkpoint_path, exist_ok=True)
 
     for msg in proxy.receive():
         if msg.tag == MsgTag.EXIT:
@@ -41,9 +42,8 @@
         elif msg.tag == MsgTag.INIT_POLICIES:
             for id_ in msg.body[MsgKey.POLICY_IDS]:
                 policy_dict[id_] = policy_func_dict[id_](id_)
-                checkpoint_path[id_] = os.path.join(checkpoint_dir, id_) if checkpoint_dir else None
-                if load_policy_dir:
-                    path = os.path.join(load_policy_dir, id_)
+                if load_path:
+                    path = os.path.join(load_path, id_)
                     if os.path.exists(path):
                         policy_dict[id_].load(path)
                         logger.info(f"Loaded policy {id_} from {path}")
@@ -71,9 +71,10 @@
                 logger.info("learning from batch")
                 policy_dict[id_].learn(info)
 
-                if checkpoint_path[id_]:
-                    policy_dict[id_].save(checkpoint_path[id_])
-                    logger.info(f"Saved policy {id_} to {checkpoint_path[id_]}")
+                if checkpoint_path:
+                    save_path = os.path.join(checkpoint_path, id_)
+                    policy_dict[id_].save(save_path)
+                    logger.info(f"Saved policy {id_} to {save_path}")
 
             msg_body = {
                 MsgKey.POLICY_STATE: {name: policy_dict[name].get_state() for name in msg.body[MsgKey.ROLLOUT_INFO]}
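The checkpoint change above replaces the old per-policy `checkpoint_path` dict with a single `CHECKPOINT_PATH` directory, deriving each policy's save path at save time. A hypothetical illustration of the resulting on-disk layout (the directory and policy ids are made up):

import os

checkpoint_path = "/tmp/maro_checkpoints"  # stand-in for from_env("CHECKPOINT_PATH")
os.makedirs(checkpoint_path, exist_ok=True)

for id_ in ["policy_a", "policy_b"]:  # hypothetical policy ids
    save_path = os.path.join(checkpoint_path, id_)
    # policy_dict[id_].save(save_path) writes one entry per policy here
    print(save_path)  # /tmp/maro_checkpoints/policy_a, /tmp/maro_checkpoints/policy_b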
66 changes: 0 additions & 66 deletions maro/rl/workflows/policy_manager.py

This file was deleted.

41 changes: 0 additions & 41 deletions maro/rl/workflows/rollout.py

This file was deleted.

45 changes: 0 additions & 45 deletions maro/rl/workflows/rollout_manager.py

This file was deleted.

14 changes: 7 additions & 7 deletions maro/rl/workflows/task_queue.py
@@ -7,22 +7,22 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    num_hosts = from_env("NUMHOSTS", required=False, default=0)
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    data_parallelism = from_env("DATAPARALLELISM", required=False, default=1)
+    num_hosts = from_env("NUM_HOSTS", required=False, default=0)
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    data_parallelism = from_env("DATA_PARALLELISM", required=False, default=1)
     worker_id_list = [f"GRAD_WORKER.{i}" for i in range(data_parallelism)]
 
     task_queue(
         worker_id_list,
         num_hosts,
         len(policy_func_dict),
         single_task_limit=from_env("SINGLETASKLIMIT", required=False, default=0.5),
-        group=from_env("POLICYGROUP", required=False, default="learn"),
+        group=from_env("POLICY_GROUP", required=False, default="learn"),
         proxy_kwargs={
             "redis_address": (
-                from_env("REDISHOST", required=False, default="maro-redis"),
-                from_env("REDISPORT", required=False, default=6379)),
+                from_env("REDIS_HOST", required=False, default="maro-redis"),
+                from_env("REDIS_PORT", required=False, default=6379)),
             "max_peer_discovery_retries": 50
         },
-        logger=get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), "TASK_QUEUE")
+        logger=get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), "TASK_QUEUE")
     )
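These scripts are meant to be launched with their configuration injected via the environment (see the TODO about the docker compose script in grad_worker.py). A hypothetical local launch of the task queue using the renamed variables (all values are illustrative only):

import os
import subprocess

env = dict(os.environ)
env.update({
    "JOB": "demo_job",                      # hypothetical job name
    "SCENARIO_PATH": "/maro/scenarios/cim",  # hypothetical scenario location
    "NUM_HOSTS": "2",
    "DATA_PARALLELISM": "4",
    "POLICY_GROUP": "learn",
    "REDIS_HOST": "maro-redis",
    "REDIS_PORT": "6379",
})
subprocess.run(["python", "-m", "maro.rl.workflows.task_queue"], env=env, check=True)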
4 changes: 2 additions & 2 deletions maro/rl_v3/__init__.py
@@ -1,3 +1,3 @@
-from .workflow import run_workflow_centralized_mode
+# from .workflow import run_workflow_centralized_mode
 
-__all__ = ["run_workflow_centralized_mode"]
+# __all__ = ["run_workflow_centralized_mode"]
Empty file.
72 changes: 72 additions & 0 deletions maro/rl_v3/distributed/dispatcher.py
@@ -0,0 +1,72 @@
from typing import Callable, Dict

import zmq
from tornado.ioloop import IOLoop
from zmq import Context
from zmq.eventloop.zmqstream import ZMQStream

from maro.rl_v3.utils.distributed import bytes_to_pyobj, string_to_bytes


class Dispatcher(object):
    def __init__(
        self,
        host: str,
        num_workers: int,
        frontend_port: int = 10000,
        backend_port: int = 10001,
        hash_fn: Callable[[str], int] = hash
    ) -> None:
        # ZMQ sockets and streams
        self._context = Context.instance()
        self._req_socket = self._context.socket(zmq.ROUTER)
        self._req_socket.bind(f"tcp://{host}:{frontend_port}")
        self._req_receiver = ZMQStream(self._req_socket)
        self._route_socket = self._context.socket(zmq.ROUTER)
        self._route_socket.bind(f"tcp://{host}:{backend_port}")
        self._router = ZMQStream(self._route_socket)

        self._event_loop = IOLoop.current()

        # register handlers
        self._req_receiver.on_recv(self._route_request_to_compute_node)
        self._req_receiver.on_send(self.log_send_result)
        self._router.on_recv(self._send_result_to_requester)
        self._router.on_send(self.log_route_request)

        # bookkeeping
        self._num_workers = num_workers
        self._hash_fn = hash_fn
        self._ops2node: Dict[str, int] = {}

    def _route_request_to_compute_node(self, msg: list) -> None:
        ops_name, _, req = msg
        print(f"Received request from {ops_name}")
        if ops_name not in self._ops2node:
            self._ops2node[ops_name] = self._hash_fn(ops_name) % self._num_workers
            print(f"Placing {ops_name} at worker node {self._ops2node[ops_name]}")
        worker_id = f'worker.{self._ops2node[ops_name]}'
        self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req])

    def _send_result_to_requester(self, msg: list) -> None:
        worker_id, _, result = msg[:3]
        if result != b"READY":
            self._req_receiver.send_multipart(msg[2:])
        else:
            print(f"{worker_id} ready")

    def start(self) -> None:
        self._event_loop.start()

    def stop(self) -> None:
        self._event_loop.stop()

    @staticmethod
    def log_route_request(msg: list, status: object) -> None:
        worker_id, _, ops_name, _, req = msg
        req = bytes_to_pyobj(req)
        print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}")

    @staticmethod
    def log_send_result(msg: list, status: object) -> None:
        print(f"Returned result for {msg[0]}")
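The worker side of this protocol is not part of the diff. Below is a minimal sketch of a compute node that would satisfy it, assuming workers are DEALER sockets whose identities follow the `worker.{i}` format targeted by `_route_request_to_compute_node`, and that the READY handshake and frame layout match what `_send_result_to_requester` expects:

import zmq

worker_id = "worker.0"  # must match the dispatcher's "worker.{i}" naming
context = zmq.Context.instance()
socket = context.socket(zmq.DEALER)
socket.setsockopt_string(zmq.IDENTITY, worker_id)
socket.connect("tcp://127.0.0.1:10001")  # dispatcher host and backend_port

socket.send_multipart([b"", b"READY"])  # announce availability; logged by the dispatcher
while True:
    # Frames forwarded by the dispatcher: [empty, ops_name, empty, request payload]
    _, ops_name, _, req = socket.recv_multipart()
    result = req  # placeholder: deserialize req and run the requested computation
    # Echo ops_name so the dispatcher can route the result back to the requester
    socket.send_multipart([b"", ops_name, b"", result])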
1 change: 0 additions & 1 deletion maro/rl_v3/learning/__init__.py
@@ -1,5 +1,4 @@
 from .env_sampler import AbsAgentWrapper, AbsEnvSampler, CacheElement, ExpElement, SimpleAgentWrapper
-from .trainer_manager import AbsTrainerManager, SimpleTrainerManager
 
 __all__ = [
     "AbsAgentWrapper", "AbsEnvSampler", "CacheElement", "ExpElement", "SimpleAgentWrapper",
2 changes: 1 addition & 1 deletion maro/rl_v3/model/__init__.py
@@ -11,5 +11,5 @@
     "MultiQNet",
     "ContinuousPolicyNet", "DiscretePolicyNet", "PolicyNet",
     "ContinuousQNet", "DiscreteQNet", "QNet",
-    "VNet"
+    "VNet",
 ]
2 changes: 1 addition & 1 deletion maro/rl_v3/model/multi_q_net.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape
+from maro.rl_v3.utils import match_shape, SHAPE_CHECK_FLAG
 
 from .abs_net import AbsNet
 
2 changes: 1 addition & 1 deletion maro/rl_v3/policy/__init__.py
@@ -5,5 +5,5 @@
 __all__ = [
     "AbsPolicy", "DummyPolicy", "RLPolicy", "RuleBasedPolicy",
     "ContinuousRLPolicy",
-    "DiscretePolicyGradient", "DiscreteRLPolicy", "ValueBasedPolicy"
+    "DiscretePolicyGradient", "DiscreteRLPolicy", "ValueBasedPolicy",
 ]