-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
ysqyang
committed
Dec 27, 2021
1 parent
8bdec10
commit 9ebabaa
Showing
24 changed files
with
703 additions
and
614 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import Callable | ||
|
||
import zmq | ||
from tornado.ioloop import IOLoop | ||
from zmq import Context | ||
from zmq.eventloop.zmqstream import ZMQStream | ||
|
||
from maro.rl_v3.distributed.utils import bytes_to_pyobj, string_to_bytes | ||
|
||
|
||
class Dispatcher(object): | ||
def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, backend_port: int = 10001, hash_fn: Callable = hash): | ||
# ZMQ sockets and streams | ||
self._context = Context.instance() | ||
self._req_socket = self._context.socket(zmq.ROUTER) | ||
self._req_socket.bind(f"tcp://{host}:{frontend_port}") | ||
self._req_receiver = ZMQStream(self._req_socket) | ||
self._route_socket = self._context.socket(zmq.ROUTER) | ||
self._route_socket.bind(f"tcp://{host}:{backend_port}") | ||
self._router = ZMQStream(self._route_socket) | ||
|
||
self._event_loop = IOLoop.current() | ||
|
||
# register handlers | ||
self._req_receiver.on_recv(self._route_request_to_compute_node) | ||
self._req_receiver.on_send(self.log_send_result) | ||
self._router.on_recv(self._send_result_to_requester) | ||
self._router.on_send(self.log_route_request) | ||
|
||
# bookkeeping | ||
self._num_workers = num_workers | ||
self._hash_fn = hash_fn | ||
self._ops2node = {} | ||
|
||
def _route_request_to_compute_node(self, msg): | ||
ops_name, _, req = msg | ||
print(f"Received request from {ops_name}") | ||
if ops_name not in self._ops2node: | ||
self._ops2node[ops_name] = self._hash_fn(ops_name) % self._num_workers | ||
print(f"Placing {ops_name} at worker node {self._ops2node[ops_name]}") | ||
worker_id = f'worker.{self._ops2node[ops_name]}' | ||
self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req]) | ||
|
||
def _send_result_to_requester(self, msg): | ||
worker_id, _, result = msg[:3] | ||
if result != b"READY": | ||
self._req_receiver.send_multipart(msg[2:]) | ||
else: | ||
print(f"{worker_id} ready") | ||
|
||
def start(self): | ||
self._event_loop.start() | ||
|
||
def stop(self): | ||
self._event_loop.stop() | ||
|
||
@staticmethod | ||
def log_route_request(msg, status): | ||
worker_id, _, ops_name, _, req = msg | ||
req = bytes_to_pyobj(req) | ||
print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}") | ||
|
||
@staticmethod | ||
def log_send_result(msg, status): | ||
print(f"Returned result for {msg[0]}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Tuple | ||
|
||
import zmq | ||
from zmq.asyncio import Context | ||
|
||
from .utils import bytes_to_pyobj, pyobj_to_bytes, string_to_bytes | ||
|
||
|
||
def remote_method(ops_name, func_name: str, dispatcher_address): | ||
async def remote_call(*args, **kwargs): | ||
req = {"func": func_name, "args": args, "kwargs": kwargs} | ||
context = Context.instance() | ||
sock = context.socket(zmq.REQ) | ||
sock.identity = string_to_bytes(ops_name) | ||
sock.connect(dispatcher_address) | ||
sock.send(pyobj_to_bytes(req)) | ||
print(f"sent request {func_name} for {ops_name}") | ||
result = bytes_to_pyobj(await sock.recv()) | ||
print(f"result for request {func_name} for {ops_name}: {result}") | ||
sock.close() | ||
return result | ||
|
||
return remote_call | ||
|
||
|
||
class RemoteOps(object): | ||
def __init__(self, ops_name, dispatcher_address: Tuple[str, int]): | ||
self._ops_name = ops_name | ||
# self._functions = {name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))} | ||
host, port = dispatcher_address | ||
self._dispatcher_address = f"tcp://{host}:{port}" | ||
|
||
def __getattribute__(self, attr_name: str): | ||
# Ignore methods that belong to the parent class | ||
try: | ||
return super().__getattribute__(attr_name) | ||
except AttributeError: | ||
pass | ||
|
||
return remote_method(self._ops_name, attr_name, self._dispatcher_address) |
Oops, something went wrong.