-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)
* Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * dsitributed training pipeline draft * added temporary test files for review purposes * Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py * fixed merge conflicts * unified sync and async interfaces * refactored rl_v3; refinement in progress * Finish the runnable pipeline under new design * Remove outdated files; refine class names; optimize imports; * Lint * Minor maddpg related refinement * Lint Co-authored-by: Default <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com>
- Loading branch information
1 parent
8e4ad49
commit f80dcc3
Showing
51 changed files
with
2,054 additions
and
1,675 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .workflow import run_workflow_centralized_mode | ||
# from .workflow import run_workflow_centralized_mode | ||
|
||
__all__ = ["run_workflow_centralized_mode"] | ||
# __all__ = ["run_workflow_centralized_mode"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import Callable, Dict | ||
|
||
import zmq | ||
from tornado.ioloop import IOLoop | ||
from zmq import Context | ||
from zmq.eventloop.zmqstream import ZMQStream | ||
|
||
from maro.rl_v3.utils.distributed import bytes_to_pyobj, string_to_bytes | ||
|
||
|
||
class Dispatcher(object): | ||
def __init__( | ||
self, | ||
host: str, | ||
num_workers: int, | ||
frontend_port: int = 10000, | ||
backend_port: int = 10001, | ||
hash_fn: Callable[[str], int] = hash | ||
) -> None: | ||
# ZMQ sockets and streams | ||
self._context = Context.instance() | ||
self._req_socket = self._context.socket(zmq.ROUTER) | ||
self._req_socket.bind(f"tcp://{host}:{frontend_port}") | ||
self._req_receiver = ZMQStream(self._req_socket) | ||
self._route_socket = self._context.socket(zmq.ROUTER) | ||
self._route_socket.bind(f"tcp://{host}:{backend_port}") | ||
self._router = ZMQStream(self._route_socket) | ||
|
||
self._event_loop = IOLoop.current() | ||
|
||
# register handlers | ||
self._req_receiver.on_recv(self._route_request_to_compute_node) | ||
self._req_receiver.on_send(self.log_send_result) | ||
self._router.on_recv(self._send_result_to_requester) | ||
self._router.on_send(self.log_route_request) | ||
|
||
# bookkeeping | ||
self._num_workers = num_workers | ||
self._hash_fn = hash_fn | ||
self._ops2node: Dict[str, int] = {} | ||
|
||
def _route_request_to_compute_node(self, msg: list) -> None: | ||
ops_name, _, req = msg | ||
print(f"Received request from {ops_name}") | ||
if ops_name not in self._ops2node: | ||
self._ops2node[ops_name] = self._hash_fn(ops_name) % self._num_workers | ||
print(f"Placing {ops_name} at worker node {self._ops2node[ops_name]}") | ||
worker_id = f'worker.{self._ops2node[ops_name]}' | ||
self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req]) | ||
|
||
def _send_result_to_requester(self, msg: list) -> None: | ||
worker_id, _, result = msg[:3] | ||
if result != b"READY": | ||
self._req_receiver.send_multipart(msg[2:]) | ||
else: | ||
print(f"{worker_id} ready") | ||
|
||
def start(self) -> None: | ||
self._event_loop.start() | ||
|
||
def stop(self) -> None: | ||
self._event_loop.stop() | ||
|
||
@staticmethod | ||
def log_route_request(msg: list, status: object) -> None: | ||
worker_id, _, ops_name, _, req = msg | ||
req = bytes_to_pyobj(req) | ||
print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}") | ||
|
||
@staticmethod | ||
def log_send_result(msg: list, status: object) -> None: | ||
print(f"Returned result for {msg[0]}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.