Skip to content

Commit

Permalink
Several code style refinements (#451)
Browse files Browse the repository at this point in the history
* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py
  • Loading branch information
lihuoran authored Dec 28, 2021
1 parent a904142 commit 02ac9cc
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 56 deletions.
25 changes: 16 additions & 9 deletions maro/rl_v3/distributed/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, Dict

import zmq
from tornado.ioloop import IOLoop
Expand All @@ -9,7 +9,14 @@


class Dispatcher(object):
def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, backend_port: int = 10001, hash_fn: Callable = hash):
def __init__(
self,
host: str,
num_workers: int,
frontend_port: int = 10000,
backend_port: int = 10001,
hash_fn: Callable[[str], int] = hash
) -> None:
# ZMQ sockets and streams
self._context = Context.instance()
self._req_socket = self._context.socket(zmq.ROUTER)
Expand All @@ -30,9 +37,9 @@ def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, back
# bookkeeping
self._num_workers = num_workers
self._hash_fn = hash_fn
self._ops2node = {}
self._ops2node: Dict[str, int] = {}

def _route_request_to_compute_node(self, msg):
def _route_request_to_compute_node(self, msg: list) -> None:
ops_name, _, req = msg
print(f"Received request from {ops_name}")
if ops_name not in self._ops2node:
Expand All @@ -41,25 +48,25 @@ def _route_request_to_compute_node(self, msg):
worker_id = f'worker.{self._ops2node[ops_name]}'
self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req])

def _send_result_to_requester(self, msg):
def _send_result_to_requester(self, msg: list) -> None:
worker_id, _, result = msg[:3]
if result != b"READY":
self._req_receiver.send_multipart(msg[2:])
else:
print(f"{worker_id} ready")

def start(self):
def start(self) -> None:
self._event_loop.start()

def stop(self):
def stop(self) -> None:
self._event_loop.stop()

@staticmethod
def log_route_request(msg, status):
def log_route_request(msg: list, status: object) -> None:
worker_id, _, ops_name, _, req = msg
req = bytes_to_pyobj(req)
print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}")

@staticmethod
def log_send_result(msg, status):
def log_send_result(msg: list, status: object) -> None:
print(f"Returned result for {msg[0]}")
10 changes: 6 additions & 4 deletions maro/rl_v3/distributed/remote.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Tuple
from typing import Callable, Tuple

import zmq
from zmq.asyncio import Context

from .utils import bytes_to_pyobj, pyobj_to_bytes, string_to_bytes


def remote_method(ops_name, func_name: str, dispatcher_address):
def remote_method(ops_name: str, func_name: str, dispatcher_address: str) -> Callable:
async def remote_call(*args, **kwargs):
req = {"func": func_name, "args": args, "kwargs": kwargs}
context = Context.instance()
Expand All @@ -26,11 +26,13 @@ async def remote_call(*args, **kwargs):
class RemoteOps(object):
def __init__(self, ops_name, dispatcher_address: Tuple[str, int]):
self._ops_name = ops_name
# self._functions = {name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))}
# self._functions = {
# name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))
# }
host, port = dispatcher_address
self._dispatcher_address = f"tcp://{host}:{port}"

def __getattribute__(self, attr_name: str):
def __getattribute__(self, attr_name: str) -> object:
# Ignore methods that belong to the parent class
try:
return super().__getattribute__(attr_name)
Expand Down
23 changes: 16 additions & 7 deletions maro/rl_v3/distributed/train_ops_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@


class TrainOpsWorker(object):
def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str, router_port: int = 10001):
def __init__(
self,
idx: int,
ops_creator: Dict[str, Callable[[str], object]], # TODO: Callable type?
router_host: str,
router_port: int = 10001
) -> None:
# ZMQ sockets and streams
self._id = f"worker.{idx}"
self._ops_creator = ops_creator
Expand All @@ -27,25 +33,28 @@ def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str,
self._task_receiver.on_recv(self._compute)
self._task_receiver.on_send(self.log_send_result)

self._ops_dict = {}
self._ops_dict: Dict[str, object] = {} # TODO: value type?

def _compute(self, msg):
def _compute(self, msg: list) -> None:
ops_name = bytes_to_string(msg[1])
req = bytes_to_pyobj(msg[-1])
assert isinstance(req, dict)

if ops_name not in self._ops_dict:
self._ops_dict[ops_name] = self._ops_creator[ops_name](ops_name)
print(f"Created ops instance {ops_name} at worker {self._id}")

func_name, args, kwargs = req["func"], req["args"], req["kwargs"]
result = getattr(self._ops_dict[ops_name], func_name)(*args, **kwargs)
func = getattr(self._ops_dict[ops_name], func_name)
result = func(*args, **kwargs)
self._task_receiver.send_multipart([b"", msg[1], b"", pyobj_to_bytes(result)])

def start(self):
def start(self) -> None:
self._event_loop.start()

def stop(self):
def stop(self) -> None:
self._event_loop.stop()

@staticmethod
def log_send_result(msg, status):
def log_send_result(msg: list, status: object) -> None:
print(f"Returning result for {msg[1]}")
11 changes: 6 additions & 5 deletions maro/rl_v3/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import asyncio
import pickle
from collections.abc import Callable

DEFAULT_MSG_ENCODING = "utf-8"


def string_to_bytes(s: str):
def string_to_bytes(s: str) -> bytes:
return s.encode(DEFAULT_MSG_ENCODING)


def bytes_to_string(bytes_: bytes):
def bytes_to_string(bytes_: bytes) -> str:
return bytes_.decode(DEFAULT_MSG_ENCODING)


def pyobj_to_bytes(pyobj):
def pyobj_to_bytes(pyobj) -> bytes:
return pickle.dumps(pyobj)


def bytes_to_pyobj(bytes_: bytes):
def bytes_to_pyobj(bytes_: bytes) -> object:
return pickle.loads(bytes_)


def sync(async_func, *args, **kwargs):
def sync(async_func: Callable, *args, **kwargs) -> object:
return asyncio.get_event_loop().run_until_complete(async_func(*args, **kwargs))
52 changes: 28 additions & 24 deletions maro/rl_v3/policy_trainer/abs_trainer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple

import torch

from maro.rl_v3.distributed.remote import RemoteOps
from maro.rl_v3.replay_memory import MultiReplayMemory, ReplayMemory
from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch
from maro.rl_v3.utils import AbsTransitionBatch, MultiTransitionBatch, TransitionBatch


class AbsTrainer(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -41,7 +41,7 @@ def name(self) -> str:
return self._name

@abstractmethod
def train_step(self) -> None:
async def train_step(self) -> None:
"""
Run a training step to update all the policies that this trainer is responsible for.
"""
Expand Down Expand Up @@ -75,22 +75,22 @@ class SingleTrainer(AbsTrainer, metaclass=ABCMeta):
def __init__(
self,
name: str,
ops_creator: Dict[str, Callable],
ops_creator: Dict[str, Callable], # TODO
dispatcher_address: Tuple[str, int] = None,
device: str = None,
enable_data_parallelism: bool = False,
train_batch_size: int = 128
) -> None:
super(SingleTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size)

self._replay_memory: Optional[ReplayMemory] = None
ops_name = [name for name in ops_creator if name.startswith(f"{self._name}.")]
if len(ops_name) > 1:

ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")]
if len(ops_names) > 1:
raise ValueError(f"trainer {self._name} cannot have more than one policy assigned to it")
ops_name = ops_name.pop()
if dispatcher_address:
self._ops = RemoteOps(ops_name, dispatcher_address)
else:
self._ops = ops_creator[ops_name](ops_name)

ops_name = ops_names.pop()
self._ops = RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name)

def record(self, transition_batch: TransitionBatch) -> None:
"""
Expand All @@ -99,6 +99,7 @@ def record(self, transition_batch: TransitionBatch) -> None:
Args:
transition_batch (TransitionBatch): A TransitionBatch item that contains a batch of experiences.
"""
assert isinstance(transition_batch, TransitionBatch)
self._replay_memory.put(transition_batch)

def _get_batch(self, batch_size: int = None) -> TransitionBatch:
Expand Down Expand Up @@ -130,18 +131,21 @@ def __init__(
train_batch_size: int = 128
) -> None:
super(MultiTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size)

self._replay_memory: Optional[MultiReplayMemory] = None
ops_names = [name for name in ops_creator if name.startswith(f"{self._name}.")]
if len(ops_names) < 2:
raise ValueError(f"trainer {self._name} cannot have less than 2 policies assigned to it")
if dispatcher_address:
self._ops_list = [RemoteOps(ops_name, dispatcher_address) for ops_name in ops_names]
else:
self._ops_list = [ops_creator[ops_name](ops_name) for ops_name in ops_names]

ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")]
# if len(ops_names) < 2:
#     raise ValueError(f"trainer {self._name} cannot have less than 2 policies assigned to it")

self._ops_list = [
RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name)
for ops_name in ops_names
]

@property
def num_policies(self) -> int:
return len(self._ops_list)
return len(self._ops_list) # TODO

def record(self, transition_batch: MultiTransitionBatch) -> None:
"""
Expand All @@ -157,20 +161,20 @@ def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch:


class BatchTrainer:
def __init__(self, trainers: List[Union[SingleTrainer, MultiTrainer]]):
def __init__(self, trainers: List[AbsTrainer]) -> None:
self._trainers = trainers
self._trainer_dict = {trainer.name: trainer for trainer in self._trainers}

def record(self, batch_by_trainer: dict):
def record(self, batch_by_trainer: Dict[str, AbsTransitionBatch]) -> None:
for trainer_name, batch in batch_by_trainer.items():
self._trainer_dict[trainer_name].record(batch)

def train(self):
def train(self) -> None:
try:
asyncio.run(self._train())
asyncio.run(self._train_impl())
except TypeError:
for trainer in self._trainers:
trainer.train_step()

async def _train(self):
async def _train_impl(self) -> None:
await asyncio.gather(*[trainer.train_step() for trainer in self._trainers])
8 changes: 7 additions & 1 deletion maro/rl_v3/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Union

from .objects import SHAPE_CHECK_FLAG
from .torch_util import match_shape, ndarray_to_tensor
from .transition_batch import MultiTransitionBatch, TransitionBatch


AbsTransitionBatch = Union[TransitionBatch, MultiTransitionBatch]


__all__ = [
"SHAPE_CHECK_FLAG",
"match_shape", "ndarray_to_tensor",
"MultiTransitionBatch", "TransitionBatch"
"AbsTransitionBatch", "MultiTransitionBatch", "TransitionBatch"
]
13 changes: 7 additions & 6 deletions maro/rl_v3/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import importlib
import os
import sys
from types import ModuleType
from typing import List, Union

from maro.utils import Logger


def from_env(var_name, required=True, default=None):
def from_env(var_name: str, required: bool = True, default: object = None) -> object:
if var_name not in os.environ:
if required:
raise KeyError(f"Missing environment variable: {var_name}")
Expand All @@ -26,7 +27,7 @@ def from_env(var_name, required=True, default=None):
return var


def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int):
def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int) -> List[int]:
"""Helper function to the policy evaluation schedule.
Args:
Expand All @@ -51,21 +52,21 @@ def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int):
return schedule


def get_module(path: str):
def get_module(path: str) -> ModuleType:
path = os.path.normpath(path)
sys.path.insert(0, os.path.dirname(path))
return importlib.import_module(os.path.basename(path))


def get_log_path(dir: str, job_name: str):
def get_log_path(dir: str, job_name: str) -> str:
return os.path.join(dir, f"{job_name}.log")


def get_logger(dir: str, job_name: str, tag: str):
def get_logger(dir: str, job_name: str, tag: str) -> Logger:
return Logger(tag, dump_path=get_log_path(dir, job_name), dump_mode="a")


def get_checkpoint_path(dir: str = None):
def get_checkpoint_path(dir: str = None) -> str:
if dir:
os.makedirs(dir, exist_ok=True)
return dir

0 comments on commit 02ac9cc

Please sign in to comment.