distributed training pipeline draft
ysqyang committed Dec 27, 2021
1 parent 8bdec10 commit 9ebabaa
Showing 24 changed files with 703 additions and 614 deletions.
17 changes: 9 additions & 8 deletions maro/rl/workflows/grad_worker.py
@@ -9,26 +9,27 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    # TODO: WORKERID in docker compose script.
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    worker_id = f"GRAD_WORKER.{from_env('WORKERID')}"
-    num_hosts = from_env("NUMHOSTS") if from_env("POLICYMANAGERTYPE") == "distributed" else 0
+    # TODO: WORKER_ID in docker compose script.
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    worker_id = f"GRAD_WORKER.{from_env('WORKER_ID')}"
+    num_hosts = from_env("NUM_HOSTS") if from_env("POLICY_MANAGER_TYPE") == "distributed" else 0
     max_cached_policies = from_env("MAXCACHED", required=False, default=10)
 
-    group = from_env("POLICYGROUP", required=False, default="learn")
+    group = from_env("POLICY_GROUP", required=False, default="learn")
     policy_dict = {}
     active_policies = []
     if num_hosts == 0:
         # no remote nodes for policy hosts
         num_hosts = len(policy_func_dict)
 
+    logger = get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), worker_id)
+
     peers = {"policy_manager": 1, "policy_host": num_hosts, "task_queue": 1}
     proxy = Proxy(
-        group, "grad_worker", peers, component_name=worker_id,
-        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
+        group, "grad_worker", peers, component_name=worker_id, logger=logger,
+        redis_address=(from_env("REDIS_HOST"), from_env("REDIS_PORT")),
         max_peer_discovery_retries=50
     )
-    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id)
 
     for msg in proxy.receive():
         if msg.tag == MsgTag.EXIT:
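
Each of these workflow entry scripts pulls its settings through `from_env` from `maro.rl.workflows.helpers`. That helper is not part of this diff, so the following is only a minimal sketch of the contract the call sites above imply (name lookup, a `required` flag, an optional `default`); the integer coercion is an assumption so that numeric settings such as REDIS_PORT come back as ints:

import os
from typing import Any

_NO_DEFAULT = object()

def from_env(var_name: str, required: bool = True, default: Any = _NO_DEFAULT) -> Any:
    # A required variable with no fallback raises; otherwise the default is used.
    if var_name not in os.environ:
        if required and default is _NO_DEFAULT:
            raise KeyError(f"missing required environment variable: {var_name}")
        return default
    value = os.environ[var_name]
    try:
        return int(value)  # assumed convenience for numeric settings
    except ValueError:
        return value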
35 changes: 18 additions & 17 deletions maro/rl/workflows/policy_host.py
@@ -9,29 +9,30 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    host_id = f"POLICY_HOST.{from_env('HOSTID')}"
+    host_id = f"POLICY_HOST.{from_env('HOST_ID')}"
     peers = {"policy_manager": 1}
-    data_parallelism = from_env("DATAPARALLELISM", required=False, default=1)
+    data_parallelism = from_env("DATA_PARALLELISM", required=False, default=1)
     if data_parallelism > 1:
         peers["grad_worker"] = data_parallelism
         peers["task_queue"] = 1
 
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    group = from_env("POLICYGROUP")
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    group = from_env("POLICY_GROUP")
     policy_dict, checkpoint_path = {}, {}
 
+    logger = get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), host_id)
+
     proxy = Proxy(
         group, "policy_host", peers,
         component_name=host_id,
-        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
+        logger=logger,
+        redis_address=(from_env("REDIS_HOST"), from_env("REDIS_PORT")),
         max_peer_discovery_retries=50
     )
-    load_policy_dir = from_env("LOADDIR", required=False, default=None)
-    checkpoint_dir = from_env("CHECKPOINTDIR", required=False, default=None)
-    if checkpoint_dir:
-        os.makedirs(checkpoint_dir, exist_ok=True)
-
-    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), host_id)
+    load_path = from_env("LOAD_PATH", required=False, default=None)
+    checkpoint_path = from_env("CHECKPOINT_PATH", required=False, default=None)
+    if checkpoint_path:
+        os.makedirs(checkpoint_path, exist_ok=True)
 
     for msg in proxy.receive():
         if msg.tag == MsgTag.EXIT:
@@ -41,9 +42,8 @@
         elif msg.tag == MsgTag.INIT_POLICIES:
             for id_ in msg.body[MsgKey.POLICY_IDS]:
                 policy_dict[id_] = policy_func_dict[id_](id_)
-                checkpoint_path[id_] = os.path.join(checkpoint_dir, id_) if checkpoint_dir else None
-                if load_policy_dir:
-                    path = os.path.join(load_policy_dir, id_)
+                if load_path:
+                    path = os.path.join(load_path, id_)
                     if os.path.exists(path):
                         policy_dict[id_].load(path)
                         logger.info(f"Loaded policy {id_} from {path}")
@@ -71,9 +71,10 @@
                 logger.info("learning from batch")
                 policy_dict[id_].learn(info)
 
-                if checkpoint_path[id_]:
-                    policy_dict[id_].save(checkpoint_path[id_])
-                    logger.info(f"Saved policy {id_} to {checkpoint_path[id_]}")
+                if checkpoint_path:
+                    save_path = os.path.join(checkpoint_path, id_)
+                    policy_dict[id_].save(save_path)
+                    logger.info(f"Saved policy {id_} to {save_path}")
 
                 msg_body = {
                     MsgKey.POLICY_STATE: {name: policy_dict[name].get_state() for name in msg.body[MsgKey.ROLLOUT_INFO]}
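
`get_logger` comes from the same helpers module and, judging only from the call sites (a directory, the job name, a component name), writes per-component log files. A compatible sketch under those assumptions; the actual handler set and format in the repository may differ:

import logging
import os

def get_logger(dir_path: str, job: str, component: str) -> logging.Logger:
    # One file per component, e.g. <dir_path>/POLICY_HOST.0.log.
    logger = logging.getLogger(f"{job}.{component}")
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(os.path.join(dir_path, f"{component}.log"))
    handler.setFormatter(logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s"))
    logger.addHandler(handler)
    return logger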
66 changes: 0 additions & 66 deletions maro/rl/workflows/policy_manager.py

This file was deleted.

41 changes: 0 additions & 41 deletions maro/rl/workflows/rollout.py

This file was deleted.

45 changes: 0 additions & 45 deletions maro/rl/workflows/rollout_manager.py

This file was deleted.

14 changes: 7 additions & 7 deletions maro/rl/workflows/task_queue.py
@@ -7,22 +7,22 @@
 from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
 
 if __name__ == "__main__":
-    num_hosts = from_env("NUMHOSTS", required=False, default=0)
-    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "policy_func_dict")
-    data_parallelism = from_env("DATAPARALLELISM", required=False, default=1)
+    num_hosts = from_env("NUM_HOSTS", required=False, default=0)
+    policy_func_dict = getattr(get_scenario_module(from_env("SCENARIO_PATH")), "policy_func_dict")
+    data_parallelism = from_env("DATA_PARALLELISM", required=False, default=1)
     worker_id_list = [f"GRAD_WORKER.{i}" for i in range(data_parallelism)]
 
     task_queue(
         worker_id_list,
         num_hosts,
         len(policy_func_dict),
         single_task_limit=from_env("SINGLETASKLIMIT", required=False, default=0.5),
-        group=from_env("POLICYGROUP", required=False, default="learn"),
+        group=from_env("POLICY_GROUP", required=False, default="learn"),
         proxy_kwargs={
             "redis_address": (
-                from_env("REDISHOST", required=False, default="maro-redis"),
-                from_env("REDISPORT", required=False, default=6379)),
+                from_env("REDIS_HOST", required=False, default="maro-redis"),
+                from_env("REDIS_PORT", required=False, default=6379)),
             "max_peer_discovery_retries": 50
         },
-        logger=get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), "TASK_QUEUE")
+        logger=get_logger(from_env("LOG_PATH", required=False, default=os.getcwd()), from_env("JOB"), "TASK_QUEUE")
     )
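
The pattern across all three scripts: every variable now uses an underscore-separated name (SCENARIO_PATH, POLICY_GROUP, NUM_HOSTS, DATA_PARALLELISM, REDIS_HOST, REDIS_PORT, LOG_PATH), and the docker compose scripts flagged in the TODO must be updated to match. As a sketch of what a local launch would have to export, with every value below illustrative rather than a repository default:

import os
import subprocess

env = dict(
    os.environ,
    JOB="cim_test",                  # hypothetical job name
    SCENARIO_PATH="scenarios/cim",   # hypothetical scenario module location
    POLICY_GROUP="learn",
    NUM_HOSTS="2",
    DATA_PARALLELISM="2",
    REDIS_HOST="maro-redis",
    REDIS_PORT="6379",
)
subprocess.Popen(["python", "maro/rl/workflows/task_queue.py"], env=env)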
Empty file.
65 changes: 65 additions & 0 deletions maro/rl_v3/distributed/dispatcher.py
@@ -0,0 +1,65 @@
+from typing import Callable
+
+import zmq
+from tornado.ioloop import IOLoop
+from zmq import Context
+from zmq.eventloop.zmqstream import ZMQStream
+
+from maro.rl_v3.distributed.utils import bytes_to_pyobj, string_to_bytes
+
+
+class Dispatcher(object):
+    def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, backend_port: int = 10001, hash_fn: Callable = hash):
+        # ZMQ sockets and streams
+        self._context = Context.instance()
+        self._req_socket = self._context.socket(zmq.ROUTER)
+        self._req_socket.bind(f"tcp://{host}:{frontend_port}")
+        self._req_receiver = ZMQStream(self._req_socket)
+        self._route_socket = self._context.socket(zmq.ROUTER)
+        self._route_socket.bind(f"tcp://{host}:{backend_port}")
+        self._router = ZMQStream(self._route_socket)
+
+        self._event_loop = IOLoop.current()
+
+        # register handlers
+        self._req_receiver.on_recv(self._route_request_to_compute_node)
+        self._req_receiver.on_send(self.log_send_result)
+        self._router.on_recv(self._send_result_to_requester)
+        self._router.on_send(self.log_route_request)
+
+        # bookkeeping
+        self._num_workers = num_workers
+        self._hash_fn = hash_fn
+        self._ops2node = {}
+
+    def _route_request_to_compute_node(self, msg):
+        ops_name, _, req = msg
+        print(f"Received request from {ops_name}")
+        if ops_name not in self._ops2node:
+            self._ops2node[ops_name] = self._hash_fn(ops_name) % self._num_workers
+            print(f"Placing {ops_name} at worker node {self._ops2node[ops_name]}")
+        worker_id = f'worker.{self._ops2node[ops_name]}'
+        self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req])
+
+    def _send_result_to_requester(self, msg):
+        worker_id, _, result = msg[:3]
+        if result != b"READY":
+            self._req_receiver.send_multipart(msg[2:])
+        else:
+            print(f"{worker_id} ready")
+
+    def start(self):
+        self._event_loop.start()
+
+    def stop(self):
+        self._event_loop.stop()
+
+    @staticmethod
+    def log_route_request(msg, status):
+        worker_id, _, ops_name, _, req = msg
+        req = bytes_to_pyobj(req)
+        print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}")
+
+    @staticmethod
+    def log_send_result(msg, status):
+        print(f"Returned result for {msg[0]}")
40 changes: 40 additions & 0 deletions maro/rl_v3/distributed/remote.py
@@ -0,0 +1,40 @@
+from typing import Tuple
+
+import zmq
+from zmq.asyncio import Context
+
+from .utils import bytes_to_pyobj, pyobj_to_bytes, string_to_bytes
+
+
+def remote_method(ops_name, func_name: str, dispatcher_address):
+    async def remote_call(*args, **kwargs):
+        req = {"func": func_name, "args": args, "kwargs": kwargs}
+        context = Context.instance()
+        sock = context.socket(zmq.REQ)
+        sock.identity = string_to_bytes(ops_name)
+        sock.connect(dispatcher_address)
+        sock.send(pyobj_to_bytes(req))
+        print(f"sent request {func_name} for {ops_name}")
+        result = bytes_to_pyobj(await sock.recv())
+        print(f"result for request {func_name} for {ops_name}: {result}")
+        sock.close()
+        return result
+
+    return remote_call
+
+
+class RemoteOps(object):
+    def __init__(self, ops_name, dispatcher_address: Tuple[str, int]):
+        self._ops_name = ops_name
+        # self._functions = {name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))}
+        host, port = dispatcher_address
+        self._dispatcher_address = f"tcp://{host}:{port}"
+
+    def __getattribute__(self, attr_name: str):
+        # Ignore methods that belong to the parent class
+        try:
+            return super().__getattribute__(attr_name)
+        except AttributeError:
+            pass
+
+        return remote_method(self._ops_name, attr_name, self._dispatcher_address)
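
The net effect of `__getattribute__` falling through to `remote_method` is that any attribute not actually defined on the instance becomes an awaitable remote call. A hypothetical usage sketch; the ops name, address, and `get_state` method are illustrative only, not names from this commit:

import asyncio

from maro.rl_v3.distributed.remote import RemoteOps

async def main():
    ops = RemoteOps("ops.0", ("127.0.0.1", 10000))
    # Looks like a local method call; actually serialized, routed through the
    # dispatcher to a worker, executed there, and the result sent back.
    result = await ops.get_state()
    print(result)

asyncio.run(main())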