Several code style refinements #451

Merged · 3 commits · Dec 28, 2021
25 changes: 16 additions & 9 deletions maro/rl_v3/distributed/dispatcher.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, Dict

 import zmq
 from tornado.ioloop import IOLoop
@@ -9,7 +9,14 @@


 class Dispatcher(object):
-    def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, backend_port: int = 10001, hash_fn: Callable = hash):
+    def __init__(
+        self,
+        host: str,
+        num_workers: int,
+        frontend_port: int = 10000,
+        backend_port: int = 10001,
+        hash_fn: Callable[[str], int] = hash
+    ) -> None:
         # ZMQ sockets and streams
         self._context = Context.instance()
         self._req_socket = self._context.socket(zmq.ROUTER)
@@ -30,9 +37,9 @@ def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, back
         # bookkeeping
         self._num_workers = num_workers
         self._hash_fn = hash_fn
-        self._ops2node = {}
+        self._ops2node: Dict[str, int] = {}

-    def _route_request_to_compute_node(self, msg):
+    def _route_request_to_compute_node(self, msg: list) -> None:
         ops_name, _, req = msg
         print(f"Received request from {ops_name}")
         if ops_name not in self._ops2node:
@@ -41,25 +48,25 @@ def _route_request_to_compute_node(self, msg):
         worker_id = f'worker.{self._ops2node[ops_name]}'
         self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req])

-    def _send_result_to_requester(self, msg):
+    def _send_result_to_requester(self, msg: list) -> None:
         worker_id, _, result = msg[:3]
         if result != b"READY":
             self._req_receiver.send_multipart(msg[2:])
         else:
             print(f"{worker_id} ready")

-    def start(self):
+    def start(self) -> None:
         self._event_loop.start()

-    def stop(self):
+    def stop(self) -> None:
         self._event_loop.stop()

     @staticmethod
-    def log_route_request(msg, status):
+    def log_route_request(msg: list, status: object) -> None:
         worker_id, _, ops_name, _, req = msg
         req = bytes_to_pyobj(req)
         print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}")

     @staticmethod
-    def log_send_result(msg, status):
+    def log_send_result(msg: list, status: object) -> None:
         print(f"Returned result for {msg[0]}")
10 changes: 6 additions & 4 deletions maro/rl_v3/distributed/remote.py
@@ -1,12 +1,12 @@
-from typing import Tuple
+from typing import Callable, Tuple

 import zmq
 from zmq.asyncio import Context

 from .utils import bytes_to_pyobj, pyobj_to_bytes, string_to_bytes


-def remote_method(ops_name, func_name: str, dispatcher_address):
+def remote_method(ops_name: str, func_name: str, dispatcher_address: str) -> Callable:
     async def remote_call(*args, **kwargs):
         req = {"func": func_name, "args": args, "kwargs": kwargs}
         context = Context.instance()
@@ -26,11 +26,13 @@ async def remote_call(*args, **kwargs):
 class RemoteOps(object):
     def __init__(self, ops_name, dispatcher_address: Tuple[str, int]):
         self._ops_name = ops_name
-        # self._functions = {name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))}
+        # self._functions = {
+        #     name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))
+        # }
         host, port = dispatcher_address
         self._dispatcher_address = f"tcp://{host}:{port}"

-    def __getattribute__(self, attr_name: str):
+    def __getattribute__(self, attr_name: str) -> object:
         # Ignore methods that belong to the parent class
         try:
             return super().__getattribute__(attr_name)
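
As context for the sharpened annotations, a sketch of how the returned async proxy is meant to be consumed. The ops name, function name, and address are hypothetical, and a dispatcher must already be listening there:

import asyncio

from maro.rl_v3.distributed.remote import remote_method

async def main() -> None:
    # "policy.0" and "get_state" are made-up names for illustration.
    call = remote_method("policy.0", "get_state", "tcp://127.0.0.1:10000")
    result = await call()  # the request dict is pickled, routed by the dispatcher, and the reply awaited
    print(result)

asyncio.run(main())
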
23 changes: 16 additions & 7 deletions maro/rl_v3/distributed/train_ops_worker.py
@@ -9,7 +9,13 @@


 class TrainOpsWorker(object):
-    def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str, router_port: int = 10001):
+    def __init__(
+        self,
+        idx: int,
+        ops_creator: Dict[str, Callable[[str], object]],  # TODO: Callable type?
+        router_host: str,
+        router_port: int = 10001
+    ) -> None:
         # ZMQ sockets and streams
         self._id = f"worker.{idx}"
         self._ops_creator = ops_creator
@@ -27,25 +33,28 @@ def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str,
         self._task_receiver.on_recv(self._compute)
         self._task_receiver.on_send(self.log_send_result)

-        self._ops_dict = {}
+        self._ops_dict: Dict[str, object] = {}  # TODO: value type?

-    def _compute(self, msg):
+    def _compute(self, msg: list) -> None:
         ops_name = bytes_to_string(msg[1])
         req = bytes_to_pyobj(msg[-1])
         assert isinstance(req, dict)
+
         if ops_name not in self._ops_dict:
             self._ops_dict[ops_name] = self._ops_creator[ops_name](ops_name)
             print(f"Created ops instance {ops_name} at worker {self._id}")
+
         func_name, args, kwargs = req["func"], req["args"], req["kwargs"]
-        result = getattr(self._ops_dict[ops_name], func_name)(*args, **kwargs)
+        func = getattr(self._ops_dict[ops_name], func_name)
+        result = func(*args, **kwargs)
         self._task_receiver.send_multipart([b"", msg[1], b"", pyobj_to_bytes(result)])

-    def start(self):
+    def start(self) -> None:
         self._event_loop.start()

-    def stop(self):
+    def stop(self) -> None:
         self._event_loop.stop()

     @staticmethod
-    def log_send_result(msg, status):
+    def log_send_result(msg: list, status: object) -> None:
         print(f"Returning result for {msg[1]}")
11 changes: 6 additions & 5 deletions maro/rl_v3/distributed/utils.py
@@ -1,24 +1,25 @@
 import asyncio
 import pickle
+from collections import Callable

 DEFAULT_MSG_ENCODING = "utf-8"


-def string_to_bytes(s: str):
+def string_to_bytes(s: str) -> bytes:
     return s.encode(DEFAULT_MSG_ENCODING)


-def bytes_to_string(bytes_: bytes):
+def bytes_to_string(bytes_: bytes) -> str:
     return bytes_.decode(DEFAULT_MSG_ENCODING)


-def pyobj_to_bytes(pyobj):
+def pyobj_to_bytes(pyobj) -> bytes:
     return pickle.dumps(pyobj)


-def bytes_to_pyobj(bytes_: bytes):
+def bytes_to_pyobj(bytes_: bytes) -> object:
     return pickle.loads(bytes_)


-def sync(async_func, *args, **kwargs):
+def sync(async_func: Callable, *args, **kwargs) -> object:
     return asyncio.get_event_loop().run_until_complete(async_func(*args, **kwargs))
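
The newly annotated helpers are symmetric pairs, which a quick round-trip check makes explicit (a sketch; the request dict mirrors the shape remote_method pickles):

from maro.rl_v3.distributed.utils import (
    bytes_to_pyobj, bytes_to_string, pyobj_to_bytes, string_to_bytes
)

assert bytes_to_string(string_to_bytes("worker.0")) == "worker.0"

req = {"func": "ping", "args": (), "kwargs": {}}
assert bytes_to_pyobj(pyobj_to_bytes(req)) == req
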
52 changes: 28 additions & 24 deletions maro/rl_v3/policy_trainer/abs_trainer.py
@@ -1,12 +1,12 @@
 import asyncio
 from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple

 import torch

 from maro.rl_v3.distributed.remote import RemoteOps
 from maro.rl_v3.replay_memory import MultiReplayMemory, ReplayMemory
-from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch
+from maro.rl_v3.utils import AbsTransitionBatch, MultiTransitionBatch, TransitionBatch


 class AbsTrainer(object, metaclass=ABCMeta):
@@ -41,7 +41,7 @@ def name(self) -> str:
         return self._name

     @abstractmethod
-    def train_step(self) -> None:
+    async def train_step(self) -> None:
         """
         Run a training step to update all the policies that this trainer is responsible for.
         """
@@ -75,22 +75,22 @@ class SingleTrainer(AbsTrainer, metaclass=ABCMeta):
     def __init__(
         self,
         name: str,
-        ops_creator: Dict[str, Callable],
+        ops_creator: Dict[str, Callable],  # TODO
         dispatcher_address: Tuple[str, int] = None,
         device: str = None,
         enable_data_parallelism: bool = False,
         train_batch_size: int = 128
     ) -> None:
         super(SingleTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size)
+
         self._replay_memory: Optional[ReplayMemory] = None
-        ops_name = [name for name in ops_creator if name.startswith(f"{self._name}.")]
-        if len(ops_name) > 1:
+
+        ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")]
+        if len(ops_names) > 1:
             raise ValueError(f"trainer {self._name} cannot have more than one policy assigned to it")
-        ops_name = ops_name.pop()
-        if dispatcher_address:
-            self._ops = RemoteOps(ops_name, dispatcher_address)
-        else:
-            self._ops = ops_creator[ops_name](ops_name)
+
+        ops_name = ops_names.pop()
+        self._ops = RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name)

     def record(self, transition_batch: TransitionBatch) -> None:
         """
@@ -99,6 +99,7 @@ def record(self, transition_batch: TransitionBatch) -> None:
         Args:
             transition_batch (TransitionBatch): A TransitionBatch item that contains a batch of experiences.
         """
+        assert isinstance(transition_batch, TransitionBatch)
         self._replay_memory.put(transition_batch)

     def _get_batch(self, batch_size: int = None) -> TransitionBatch:
@@ -130,18 +131,21 @@ def __init__(
         train_batch_size: int = 128
     ) -> None:
         super(MultiTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size)
+
         self._replay_memory: Optional[MultiReplayMemory] = None
-        ops_names = [name for name in ops_creator if name.startswith(f"{self._name}.")]
-        if len(ops_names) < 2:
-            raise ValueError(f"trainer {self._name} cannot less than 2 policies assigned to it")
-        if dispatcher_address:
-            self._ops_list = [RemoteOps(ops_name, dispatcher_address) for ops_name in ops_names]
-        else:
-            self._ops_list = [ops_creator[ops_name](ops_name) for ops_name in ops_names]
+
+        ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")]
+        # if len(ops_names) < 2:
+        #     raise ValueError(f"trainer {self._name} cannot less than 2 policies assigned to it")
+
+        self._ops_list = [
+            RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name)
+            for ops_name in ops_names
+        ]

     @property
     def num_policies(self) -> int:
-        return len(self._ops_list)
+        return len(self._ops_list)  # TODO

     def record(self, transition_batch: MultiTransitionBatch) -> None:
         """
@@ -157,20 +161,20 @@ def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch:


 class BatchTrainer:
-    def __init__(self, trainers: List[Union[SingleTrainer, MultiTrainer]]):
+    def __init__(self, trainers: List[AbsTrainer]) -> None:
         self._trainers = trainers
         self._trainer_dict = {trainer.name: trainer for trainer in self._trainers}

-    def record(self, batch_by_trainer: dict):
+    def record(self, batch_by_trainer: Dict[str, AbsTransitionBatch]) -> None:
         for trainer_name, batch in batch_by_trainer.items():
             self._trainer_dict[trainer_name].record(batch)

-    def train(self):
+    def train(self) -> None:
         try:
-            asyncio.run(self._train())
+            asyncio.run(self._train_impl())
         except TypeError:
             for trainer in self._trainers:
                 trainer.train_step()

-    async def _train(self):
+    async def _train_impl(self) -> None:
         await asyncio.gather(*[trainer.train_step() for trainer in self._trainers])
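
The train() / _train_impl() pair above leans on asyncio.gather() raising TypeError when a trainer's train_step() is a plain method rather than a coroutine. A self-contained model of that control flow, with toy trainers standing in for the real ones:

import asyncio

class AsyncTrainer:
    async def train_step(self) -> None:  # awaitable, so gather() can schedule it
        print("gathered step")

class SyncTrainer:
    def train_step(self) -> None:  # plain call; returns None, which gather() rejects
        print("sequential step")

def train(trainers: list) -> None:
    try:
        async def _train_impl() -> None:
            await asyncio.gather(*[t.train_step() for t in trainers])
        asyncio.run(_train_impl())
    except TypeError:  # at least one train_step() was not awaitable
        for t in trainers:
            t.train_step()

train([AsyncTrainer()])  # prints "gathered step"
train([SyncTrainer()])   # prints "sequential step" twice: the list comprehension already
                         # ran the sync method once before gather() raised, then the
                         # fallback loop runs it again
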
8 changes: 7 additions & 1 deletion maro/rl_v3/utils/__init__.py
@@ -1,9 +1,15 @@
+from typing import Union
+
 from .objects import SHAPE_CHECK_FLAG
 from .torch_util import match_shape, ndarray_to_tensor
 from .transition_batch import MultiTransitionBatch, TransitionBatch

+
+AbsTransitionBatch = Union[TransitionBatch, MultiTransitionBatch]
+
+
 __all__ = [
     "SHAPE_CHECK_FLAG",
     "match_shape", "ndarray_to_tensor",
-    "MultiTransitionBatch", "TransitionBatch"
+    "AbsTransitionBatch", "MultiTransitionBatch", "TransitionBatch"
 ]
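
Since AbsTransitionBatch is a typing.Union alias, it is usable in annotations but not in isinstance() checks, which still need the concrete classes; a short sketch:

from maro.rl_v3.utils import AbsTransitionBatch, MultiTransitionBatch

def describe(batch: AbsTransitionBatch) -> str:
    # a typing.Union alias is not a class, so dispatch on the concrete type
    return "multi" if isinstance(batch, MultiTransitionBatch) else "single"
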
13 changes: 7 additions & 6 deletions maro/rl_v3/utils/common.py
@@ -4,12 +4,13 @@
 import importlib
 import os
 import sys
+from types import ModuleType
 from typing import List, Union

 from maro.utils import Logger


-def from_env(var_name, required=True, default=None):
+def from_env(var_name: str, required: bool = True, default: object = None) -> object:
     if var_name not in os.environ:
         if required:
             raise KeyError(f"Missing environment variable: {var_name}")
@@ -26,7 +27,7 @@ def from_env(var_name, required=True, default=None):
     return var


-def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int):
+def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int) -> List[int]:
     """Helper function to the policy evaluation schedule.

     Args:
@@ -51,21 +52,21 @@ def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int):
     return schedule


-def get_module(path: str):
+def get_module(path: str) -> ModuleType:
     path = os.path.normpath(path)
     sys.path.insert(0, os.path.dirname(path))
     return importlib.import_module(os.path.basename(path))


-def get_log_path(dir: str, job_name: str):
+def get_log_path(dir: str, job_name: str) -> str:
     return os.path.join(dir, f"{job_name}.log")


-def get_logger(dir: str, job_name: str, tag: str):
+def get_logger(dir: str, job_name: str, tag: str) -> Logger:
     return Logger(tag, dump_path=get_log_path(dir, job_name), dump_mode="a")


-def get_checkpoint_path(dir: str = None):
+def get_checkpoint_path(dir: str = None) -> str:
     if dir:
         os.makedirs(dir, exist_ok=True)
         return dir
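
To close, a hedged sketch of the sharpened from_env() contract as far as this diff shows it: a missing required variable raises KeyError, and any parsing of the raw string happens in the elided middle of the function. Variable names are hypothetical:

import os

from maro.rl_v3.utils.common import from_env

os.environ["JOB_NAME"] = "demo"                       # hypothetical variable
print(from_env("JOB_NAME"))                           # -> "demo" (possibly parsed by the elided code)
print(from_env("DEBUG", required=False, default=0))   # -> 0 when DEBUG is unset

try:
    from_env("MISSING_VAR")                           # required and absent
except KeyError as exc:
    print(exc)                                        # Missing environment variable: MISSING_VAR
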