From 02ac9cc9d2eb6d6c6039456001343472b9b18728 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 28 Dec 2021 17:10:12 +0800 Subject: [PATCH] Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py --- maro/rl_v3/distributed/dispatcher.py | 25 +++++++---- maro/rl_v3/distributed/remote.py | 10 +++-- maro/rl_v3/distributed/train_ops_worker.py | 23 +++++++--- maro/rl_v3/distributed/utils.py | 11 ++--- maro/rl_v3/policy_trainer/abs_trainer.py | 52 ++++++++++++---------- maro/rl_v3/utils/__init__.py | 8 +++- maro/rl_v3/utils/common.py | 13 +++--- 7 files changed, 86 insertions(+), 56 deletions(-) diff --git a/maro/rl_v3/distributed/dispatcher.py b/maro/rl_v3/distributed/dispatcher.py index ecc5027df..36806f84a 100644 --- a/maro/rl_v3/distributed/dispatcher.py +++ b/maro/rl_v3/distributed/dispatcher.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Dict import zmq from tornado.ioloop import IOLoop @@ -9,7 +9,14 @@ class Dispatcher(object): - def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, backend_port: int = 10001, hash_fn: Callable = hash): + def __init__( + self, + host: str, + num_workers: int, + frontend_port: int = 10000, + backend_port: int = 10001, + hash_fn: Callable[[str], int] = hash + ) -> None: # ZMQ sockets and streams self._context = Context.instance() self._req_socket = self._context.socket(zmq.ROUTER) @@ -30,9 +37,9 @@ def __init__(self, host: str, num_workers: int, frontend_port: int = 10000, back # bookkeeping self._num_workers = num_workers self._hash_fn = hash_fn - self._ops2node = {} + self._ops2node: Dict[str, int] = {} - def _route_request_to_compute_node(self, msg): + def _route_request_to_compute_node(self, msg: list) -> None: ops_name, _, req = msg print(f"Received request from {ops_name}") if ops_name not in self._ops2node: @@ -41,25 +48,25 @@ def _route_request_to_compute_node(self, msg): worker_id = f'worker.{self._ops2node[ops_name]}' self._router.send_multipart([string_to_bytes(worker_id), b"", ops_name, b"", req]) - def _send_result_to_requester(self, msg): + def _send_result_to_requester(self, msg: list) -> None: worker_id, _, result = msg[:3] if result != b"READY": self._req_receiver.send_multipart(msg[2:]) else: print(f"{worker_id} ready") - def start(self): + def start(self) -> None: self._event_loop.start() - def stop(self): + def stop(self) -> None: self._event_loop.stop() @staticmethod - def log_route_request(msg, status): + def log_route_request(msg: list, status: object) -> None: worker_id, _, ops_name, _, req = msg req = bytes_to_pyobj(req) print(f"Routed {ops_name}'s request {req['func']} to worker node {worker_id}") @staticmethod - def log_send_result(msg, status): + def log_send_result(msg: list, status: object) -> None: print(f"Returned result for {msg[0]}") diff --git a/maro/rl_v3/distributed/remote.py b/maro/rl_v3/distributed/remote.py index 755615508..c7573a2cc 100644 --- a/maro/rl_v3/distributed/remote.py +++ b/maro/rl_v3/distributed/remote.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Callable, Tuple import zmq from zmq.asyncio import Context @@ -6,7 +6,7 @@ from .utils import bytes_to_pyobj, pyobj_to_bytes, string_to_bytes -def remote_method(ops_name, func_name: str, dispatcher_address): +def remote_method(ops_name: str, func_name: str, dispatcher_address: str) -> Callable: async def remote_call(*args, **kwargs): req = {"func": func_name, "args": args, "kwargs": kwargs} context = Context.instance() @@ -26,11 +26,13 @@ async def remote_call(*args, **kwargs): class RemoteOps(object): def __init__(self, ops_name, dispatcher_address: Tuple[str, int]): self._ops_name = ops_name - # self._functions = {name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr))} + # self._functions = { + # name for name, _ in inspect.getmembers(train_op_cls, lambda attr: inspect.isfunction(attr)) + # } host, port = dispatcher_address self._dispatcher_address = f"tcp://{host}:{port}" - def __getattribute__(self, attr_name: str): + def __getattribute__(self, attr_name: str) -> object: # Ignore methods that belong to the parent class try: return super().__getattribute__(attr_name) diff --git a/maro/rl_v3/distributed/train_ops_worker.py b/maro/rl_v3/distributed/train_ops_worker.py index f5b6f76eb..e4a04c393 100644 --- a/maro/rl_v3/distributed/train_ops_worker.py +++ b/maro/rl_v3/distributed/train_ops_worker.py @@ -9,7 +9,13 @@ class TrainOpsWorker(object): - def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str, router_port: int = 10001): + def __init__( + self, + idx: int, + ops_creator: Dict[str, Callable[[str], object]], # TODO: Callable type? + router_host: str, + router_port: int = 10001 + ) -> None: # ZMQ sockets and streams self._id = f"worker.{idx}" self._ops_creator = ops_creator @@ -27,25 +33,28 @@ def __init__(self, idx: int, ops_creator: Dict[str, Callable], router_host: str, self._task_receiver.on_recv(self._compute) self._task_receiver.on_send(self.log_send_result) - self._ops_dict = {} + self._ops_dict: Dict[str, object] = {} # TODO: value type? - def _compute(self, msg): + def _compute(self, msg: list) -> None: ops_name = bytes_to_string(msg[1]) req = bytes_to_pyobj(msg[-1]) + assert isinstance(req, dict) + if ops_name not in self._ops_dict: self._ops_dict[ops_name] = self._ops_creator[ops_name](ops_name) print(f"Created ops instance {ops_name} at worker {self._id}") func_name, args, kwargs = req["func"], req["args"], req["kwargs"] - result = getattr(self._ops_dict[ops_name], func_name)(*args, **kwargs) + func = getattr(self._ops_dict[ops_name], func_name) + result = func(*args, **kwargs) self._task_receiver.send_multipart([b"", msg[1], b"", pyobj_to_bytes(result)]) - def start(self): + def start(self) -> None: self._event_loop.start() - def stop(self): + def stop(self) -> None: self._event_loop.stop() @staticmethod - def log_send_result(msg, status): + def log_send_result(msg: list, status: object) -> None: print(f"Returning result for {msg[1]}") diff --git a/maro/rl_v3/distributed/utils.py b/maro/rl_v3/distributed/utils.py index 33c3bf431..b74c4f8ab 100644 --- a/maro/rl_v3/distributed/utils.py +++ b/maro/rl_v3/distributed/utils.py @@ -1,24 +1,25 @@ import asyncio import pickle +from collections import Callable DEFAULT_MSG_ENCODING = "utf-8" -def string_to_bytes(s: str): +def string_to_bytes(s: str) -> bytes: return s.encode(DEFAULT_MSG_ENCODING) -def bytes_to_string(bytes_: bytes): +def bytes_to_string(bytes_: bytes) -> str: return bytes_.decode(DEFAULT_MSG_ENCODING) -def pyobj_to_bytes(pyobj): +def pyobj_to_bytes(pyobj) -> bytes: return pickle.dumps(pyobj) -def bytes_to_pyobj(bytes_: bytes): +def bytes_to_pyobj(bytes_: bytes) -> object: return pickle.loads(bytes_) -def sync(async_func, *args, **kwargs): +def sync(async_func: Callable, *args, **kwargs) -> object: return asyncio.get_event_loop().run_until_complete(async_func(*args, **kwargs)) diff --git a/maro/rl_v3/policy_trainer/abs_trainer.py b/maro/rl_v3/policy_trainer/abs_trainer.py index 8737d1849..48ac207f1 100644 --- a/maro/rl_v3/policy_trainer/abs_trainer.py +++ b/maro/rl_v3/policy_trainer/abs_trainer.py @@ -1,12 +1,12 @@ import asyncio from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple import torch from maro.rl_v3.distributed.remote import RemoteOps from maro.rl_v3.replay_memory import MultiReplayMemory, ReplayMemory -from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch +from maro.rl_v3.utils import AbsTransitionBatch, MultiTransitionBatch, TransitionBatch class AbsTrainer(object, metaclass=ABCMeta): @@ -41,7 +41,7 @@ def name(self) -> str: return self._name @abstractmethod - def train_step(self) -> None: + async def train_step(self) -> None: """ Run a training step to update all the policies that this trainer is responsible for. """ @@ -75,22 +75,22 @@ class SingleTrainer(AbsTrainer, metaclass=ABCMeta): def __init__( self, name: str, - ops_creator: Dict[str, Callable], + ops_creator: Dict[str, Callable], # TODO dispatcher_address: Tuple[str, int] = None, device: str = None, enable_data_parallelism: bool = False, train_batch_size: int = 128 ) -> None: super(SingleTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size) + self._replay_memory: Optional[ReplayMemory] = None - ops_name = [name for name in ops_creator if name.startswith(f"{self._name}.")] - if len(ops_name) > 1: + + ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")] + if len(ops_names) > 1: raise ValueError(f"trainer {self._name} cannot have more than one policy assigned to it") - ops_name = ops_name.pop() - if dispatcher_address: - self._ops = RemoteOps(ops_name, dispatcher_address) - else: - self._ops = ops_creator[ops_name](ops_name) + + ops_name = ops_names.pop() + self._ops = RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name) def record(self, transition_batch: TransitionBatch) -> None: """ @@ -99,6 +99,7 @@ def record(self, transition_batch: TransitionBatch) -> None: Args: transition_batch (TransitionBatch): A TransitionBatch item that contains a batch of experiences. """ + assert isinstance(transition_batch, TransitionBatch) self._replay_memory.put(transition_batch) def _get_batch(self, batch_size: int = None) -> TransitionBatch: @@ -130,18 +131,21 @@ def __init__( train_batch_size: int = 128 ) -> None: super(MultiTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size) + self._replay_memory: Optional[MultiReplayMemory] = None - ops_names = [name for name in ops_creator if name.startswith(f"{self._name}.")] - if len(ops_names) < 2: - raise ValueError(f"trainer {self._name} cannot less than 2 policies assigned to it") - if dispatcher_address: - self._ops_list = [RemoteOps(ops_name, dispatcher_address) for ops_name in ops_names] - else: - self._ops_list = [ops_creator[ops_name](ops_name) for ops_name in ops_names] + + ops_names = [ops_name for ops_name in ops_creator if ops_name.startswith(f"{self._name}.")] + # if len(ops_names) < 2: + # raise ValueError(f"trainer {self._name} cannot less than 2 policies assigned to it") + + self._ops_list = [ + RemoteOps(ops_name, dispatcher_address) if dispatcher_address else ops_creator[ops_name](ops_name) + for ops_name in ops_names + ] @property def num_policies(self) -> int: - return len(self._ops_list) + return len(self._ops_list) # TODO def record(self, transition_batch: MultiTransitionBatch) -> None: """ @@ -157,20 +161,20 @@ def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: class BatchTrainer: - def __init__(self, trainers: List[Union[SingleTrainer, MultiTrainer]]): + def __init__(self, trainers: List[AbsTrainer]) -> None: self._trainers = trainers self._trainer_dict = {trainer.name: trainer for trainer in self._trainers} - def record(self, batch_by_trainer: dict): + def record(self, batch_by_trainer: Dict[str, AbsTransitionBatch]) -> None: for trainer_name, batch in batch_by_trainer.items(): self._trainer_dict[trainer_name].record(batch) - def train(self): + def train(self) -> None: try: - asyncio.run(self._train()) + asyncio.run(self._train_impl()) except TypeError: for trainer in self._trainers: trainer.train_step() - async def _train(self): + async def _train_impl(self) -> None: await asyncio.gather(*[trainer.train_step() for trainer in self._trainers]) diff --git a/maro/rl_v3/utils/__init__.py b/maro/rl_v3/utils/__init__.py index c66823123..3e801be09 100644 --- a/maro/rl_v3/utils/__init__.py +++ b/maro/rl_v3/utils/__init__.py @@ -1,9 +1,15 @@ +from typing import Union + from .objects import SHAPE_CHECK_FLAG from .torch_util import match_shape, ndarray_to_tensor from .transition_batch import MultiTransitionBatch, TransitionBatch + +AbsTransitionBatch = Union[TransitionBatch, MultiTransitionBatch] + + __all__ = [ "SHAPE_CHECK_FLAG", "match_shape", "ndarray_to_tensor", - "MultiTransitionBatch", "TransitionBatch" + "AbsTransitionBatch", "MultiTransitionBatch", "TransitionBatch" ] diff --git a/maro/rl_v3/utils/common.py b/maro/rl_v3/utils/common.py index 91be24a2e..d20ca36c4 100644 --- a/maro/rl_v3/utils/common.py +++ b/maro/rl_v3/utils/common.py @@ -4,12 +4,13 @@ import importlib import os import sys +from types import ModuleType from typing import List, Union from maro.utils import Logger -def from_env(var_name, required=True, default=None): +def from_env(var_name: str, required: bool = True, default: object = None) -> object: if var_name not in os.environ: if required: raise KeyError(f"Missing environment variable: {var_name}") @@ -26,7 +27,7 @@ def from_env(var_name, required=True, default=None): return var -def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int): +def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int) -> List[int]: """Helper function to the policy evaluation schedule. Args: @@ -51,21 +52,21 @@ def get_eval_schedule(sch: Union[int, List[int]], num_episodes: int): return schedule -def get_module(path: str): +def get_module(path: str) -> ModuleType: path = os.path.normpath(path) sys.path.insert(0, os.path.dirname(path)) return importlib.import_module(os.path.basename(path)) -def get_log_path(dir: str, job_name: str): +def get_log_path(dir: str, job_name: str) -> str: return os.path.join(dir, f"{job_name}.log") -def get_logger(dir: str, job_name: str, tag: str): +def get_logger(dir: str, job_name: str, tag: str) -> Logger: return Logger(tag, dump_path=get_log_path(dir, job_name), dump_mode="a") -def get_checkpoint_path(dir: str = None): +def get_checkpoint_path(dir: str = None) -> str: if dir: os.makedirs(dir, exist_ok=True) return dir