diff --git a/examples/cim/rl/algorithms/ac.py b/examples/cim/rl/algorithms/ac.py
index 69bbc49c5..1769493df 100644
--- a/examples/cim/rl/algorithms/ac.py
+++ b/examples/cim/rl/algorithms/ac.py
@@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG
 def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
     return ActorCriticTrainer(
         name=name,
+        reward_discount=0.0,
         params=ActorCriticParams(
             get_v_critic_net_func=lambda: MyCriticNet(state_dim),
-            reward_discount=0.0,
             grad_iters=10,
             critic_loss_cls=torch.nn.SmoothL1Loss,
             min_logp=None,
diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py
index c5999424a..d62e3443d 100644
--- a/examples/cim/rl/algorithms/dqn.py
+++ b/examples/cim/rl/algorithms/dqn.py
@@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
 def get_dqn(name: str) -> DQNTrainer:
     return DQNTrainer(
         name=name,
+        reward_discount=0.0,
+        replay_memory_capacity=10000,
+        batch_size=32,
         params=DQNParams(
-            reward_discount=0.0,
             update_target_every=5,
             num_epochs=10,
             soft_update_coef=0.1,
             double=False,
-            replay_memory_capacity=10000,
             random_overwrite=False,
-            batch_size=32,
         ),
     )
diff --git a/examples/cim/rl/algorithms/maddpg.py b/examples/cim/rl/algorithms/maddpg.py
index 7d964f6bb..e6fd0a65b 100644
--- a/examples/cim/rl/algorithms/maddpg.py
+++ b/examples/cim/rl/algorithms/maddpg.py
@@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol
 def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
     return DiscreteMADDPGTrainer(
         name=name,
+        reward_discount=0.0,
         params=DiscreteMADDPGParams(
-            reward_discount=0.0,
             num_epoch=10,
             get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
             shared_critic=False,
diff --git a/examples/cim/rl/algorithms/ppo.py b/examples/cim/rl/algorithms/ppo.py
index d2e2df0d9..3600858f1 100644
--- a/examples/cim/rl/algorithms/ppo.py
+++ b/examples/cim/rl/algorithms/ppo.py
@@ -16,9 +16,9 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy
 def get_ppo(state_dim: int, name: str) -> PPOTrainer:
     return PPOTrainer(
         name=name,
+        reward_discount=0.0,
         params=PPOParams(
             get_v_critic_net_func=lambda: MyCriticNet(state_dim),
-            reward_discount=0.0,
             grad_iters=10,
             critic_loss_cls=torch.nn.SmoothL1Loss,
             min_logp=None,
diff --git a/examples/vm_scheduling/rl/algorithms/ac.py b/examples/vm_scheduling/rl/algorithms/ac.py
index 411d35d6b..94d0afd63 100644
--- a/examples/vm_scheduling/rl/algorithms/ac.py
+++ b/examples/vm_scheduling/rl/algorithms/ac.py
@@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str)
 def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
     return ActorCriticTrainer(
         name=name,
+        reward_discount=0.9,
         params=ActorCriticParams(
             get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
-            reward_discount=0.9,
             grad_iters=100,
             critic_loss_cls=torch.nn.MSELoss,
             min_logp=-20,
diff --git a/examples/vm_scheduling/rl/algorithms/dqn.py b/examples/vm_scheduling/rl/algorithms/dqn.py
index a94989418..499cb85b5 100644
--- a/examples/vm_scheduling/rl/algorithms/dqn.py
+++ b/examples/vm_scheduling/rl/algorithms/dqn.py
@@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
 def get_dqn(name: str) -> DQNTrainer:
     return DQNTrainer(
         name=name,
+        reward_discount=0.9,
+        replay_memory_capacity=10000,
+        batch_size=32,
+        data_parallelism=2,
         params=DQNParams(
-            reward_discount=0.9,
             update_target_every=5,
             num_epochs=100,
             soft_update_coef=0.1,
             double=False,
-            replay_memory_capacity=10000,
             random_overwrite=False,
-            batch_size=32,
-            data_parallelism=2,
         ),
     )
diff --git a/maro/rl/distributed/__init__.py b/maro/rl/distributed/__init__.py
index b18d1ee59..828505c04 100644
--- a/maro/rl/distributed/__init__.py
+++ b/maro/rl/distributed/__init__.py
@@ -3,8 +3,12 @@
 from .abs_proxy import AbsProxy
 from .abs_worker import AbsWorker
+from .port_config import DEFAULT_ROLLOUT_PRODUCER_PORT, DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT

 __all__ = [
     "AbsProxy",
     "AbsWorker",
+    "DEFAULT_ROLLOUT_PRODUCER_PORT",
+    "DEFAULT_TRAINING_FRONTEND_PORT",
+    "DEFAULT_TRAINING_BACKEND_PORT",
 ]
diff --git a/maro/rl/distributed/abs_worker.py b/maro/rl/distributed/abs_worker.py
index 1f191034c..7da7e9435 100644
--- a/maro/rl/distributed/abs_worker.py
+++ b/maro/rl/distributed/abs_worker.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 from abc import abstractmethod
+from typing import Union

 import zmq
 from tornado.ioloop import IOLoop
@@ -33,7 +34,7 @@ def __init__(
         super(AbsWorker, self).__init__()

         self._id = f"worker.{idx}"
-        self._logger = logger if logger else DummyLogger()
+        self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger()

         # ZMQ sockets and streams
         self._context = Context.instance()
diff --git a/maro/rl/distributed/port_config.py b/maro/rl/distributed/port_config.py
new file mode 100644
index 000000000..f0828c769
--- /dev/null
+++ b/maro/rl/distributed/port_config.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+DEFAULT_ROLLOUT_PRODUCER_PORT = 20000
+DEFAULT_TRAINING_FRONTEND_PORT = 10000
+DEFAULT_TRAINING_BACKEND_PORT = 10001
diff --git a/maro/rl/exploration/scheduling.py b/maro/rl/exploration/scheduling.py
index 1276171b1..3981729c9 100644
--- a/maro/rl/exploration/scheduling.py
+++ b/maro/rl/exploration/scheduling.py
@@ -98,14 +98,15 @@ def __init__(
         start_ep: int = 1,
         initial_value: float = None,
     ) -> None:
+        super().__init__(exploration_params, param_name, initial_value=initial_value)
+
         # validate splits
-        splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)]
+        splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
         splits.sort()
         for (ep, _), (ep2, _) in zip(splits, splits[1:]):
             if ep == ep2:
                 raise ValueError("The zeroth element of split points must be unique")

-        super().__init__(exploration_params, param_name, initial_value=initial_value)
         self.final_value = final_value
         self._splits = splits
         self._ep = start_ep
diff --git a/maro/rl/model/abs_net.py b/maro/rl/model/abs_net.py
index 499eaa1d8..a559d1124 100644
--- a/maro/rl/model/abs_net.py
+++ b/maro/rl/model/abs_net.py
@@ -4,7 +4,7 @@
 from __future__ import annotations

 from abc import ABCMeta
-from typing import Any, Dict, Optional
+from typing import Any, Dict

 import torch.nn
 from torch.optim import Optimizer
@@ -18,7 +18,11 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
     def __init__(self) -> None:
         super(AbsNet, self).__init__()

-        self._optim: Optional[Optimizer] = None
+    @property
+    def optim(self) -> Optimizer:
+        optim = getattr(self, "_optim", None)
+        assert isinstance(optim, Optimizer), "Each AbsNet must have an optimizer"
+        return optim

     def step(self, loss: torch.Tensor) -> None:
         """Run a training step to update the net's parameters according to the given loss.
@@ -26,9 +30,9 @@ def step(self, loss: torch.Tensor) -> None:
         Args:
             loss (torch.tensor): Loss used to update the model.
         """
-        self._optim.zero_grad()
+        self.optim.zero_grad()
         loss.backward()
-        self._optim.step()
+        self.optim.step()

     def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Get the gradients with respect to all parameters according to the given loss.
@@ -39,7 +43,7 @@ def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
         Returns:
             Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters.
         """
-        self._optim.zero_grad()
+        self.optim.zero_grad()
         loss.backward()
         return {name: param.grad for name, param in self.named_parameters()}
@@ -51,7 +55,7 @@ def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None:
         """
         for name, param in self.named_parameters():
             param.grad = grad[name]
-        self._optim.step()
+        self.optim.step()

     def _forward_unimplemented(self, *input: Any) -> None:
         pass
@@ -64,7 +68,7 @@ def get_state(self) -> dict:
         """
         return {
             "network": self.state_dict(),
-            "optim": self._optim.state_dict(),
+            "optim": self.optim.state_dict(),
         }

     def set_state(self, net_state: dict) -> None:
@@ -74,7 +78,7 @@ def set_state(self, net_state: dict) -> None:
             net_state (dict): A dict that contains the net's state.
         """
         self.load_state_dict(net_state["network"])
-        self._optim.load_state_dict(net_state["optim"])
+        self.optim.load_state_dict(net_state["optim"])

     def soft_update(self, other_model: AbsNet, tau: float) -> None:
         """Soft update the net's parameters according to another net, i.e.,
diff --git a/maro/rl/model/fc_block.py b/maro/rl/model/fc_block.py
index acd7bd16e..aee712e18 100644
--- a/maro/rl/model/fc_block.py
+++ b/maro/rl/model/fc_block.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 from collections import OrderedDict
-from typing import Any, List, Optional, Type
+from typing import Any, List, Optional, Tuple, Type

 import torch
 import torch.nn as nn
@@ -46,7 +46,7 @@ def __init__(
         skip_connection: bool = False,
         dropout_p: float = None,
         gradient_threshold: float = None,
-        name: str = None,
+        name: str = "NONAME",
     ) -> None:
         super(FullyConnected, self).__init__()
         self._input_dim = input_dim
@@ -101,12 +101,12 @@ def input_dim(self) -> int:
     def output_dim(self) -> int:
         return self._output_dim

-    def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module:
+    def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
         """Build a basic layer.

         BN -> Linear -> Activation -> Dropout
         """
-        components = []
+        components: List[Tuple[str, nn.Module]] = []
         if self._batch_norm:
             components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
         components.append(("linear", nn.Linear(input_dim, output_dim)))
diff --git a/maro/rl/policy/abs_policy.py b/maro/rl/policy/abs_policy.py
index c57c0db51..14b0bb3a9 100644
--- a/maro/rl/policy/abs_policy.py
+++ b/maro/rl/policy/abs_policy.py
@@ -4,7 +4,7 @@
 from __future__ import annotations

 from abc import ABCMeta, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -27,14 +27,14 @@ def __init__(self, name: str, trainable: bool) -> None:
         self._trainable = trainable

     @abstractmethod
-    def get_actions(self, states: object) -> object:
+    def get_actions(self, states: Union[list, np.ndarray]) -> Any:
         """Get actions according to states.

         Args:
-            states (object): States.
+            states (Union[list, np.ndarray]): States.
         Returns:
-            actions (object): Actions.
+            actions (Any): Actions.
         """
         raise NotImplementedError
@@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
     def __init__(self) -> None:
         super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)

-    def get_actions(self, states: object) -> None:
+    def get_actions(self, states: Union[list, np.ndarray]) -> None:
         return None

     def explore(self) -> None:
@@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
     def __init__(self, name: str) -> None:
         super(RuleBasedPolicy, self).__init__(name=name, trainable=False)

-    def get_actions(self, states: List[object]) -> List[object]:
+    def get_actions(self, states: list) -> list:
         return self._rule(states)

     @abstractmethod
-    def _rule(self, states: List[object]) -> List[object]:
+    def _rule(self, states: list) -> list:
         raise NotImplementedError

     def explore(self) -> None:
@@ -304,7 +304,7 @@ def unfreeze(self) -> None:
         raise NotImplementedError

     @abstractmethod
-    def get_state(self) -> object:
+    def get_state(self) -> dict:
         """Get the state of the policy."""
         raise NotImplementedError
diff --git a/maro/rl/policy/continuous_rl_policy.py b/maro/rl/policy/continuous_rl_policy.py
index e93cc982b..33ed3e55d 100644
--- a/maro/rl/policy/continuous_rl_policy.py
+++ b/maro/rl/policy/continuous_rl_policy.py
@@ -62,12 +62,10 @@ def __init__(
         )

         self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
-        assert self._lbounds is not None and self._ubounds is not None
-
         self._policy_net = policy_net

     @property
-    def action_bounds(self) -> Tuple[List[float], List[float]]:
+    def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]:
         return self._lbounds, self._ubounds

     @property
@@ -118,7 +116,7 @@ def eval(self) -> None:
     def train(self) -> None:
         self._policy_net.train()

-    def get_state(self) -> object:
+    def get_state(self) -> dict:
         return self._policy_net.get_state()

     def set_state(self, policy_state: dict) -> None:
diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py
index a332908dc..567e9d054 100644
--- a/maro/rl/policy/discrete_rl_policy.py
+++ b/maro/rl/policy/discrete_rl_policy.py
@@ -85,9 +85,11 @@ def __init__(
         self._exploration_func = exploration_strategy[0]
         self._exploration_params = clone(exploration_strategy[1])  # deep copy is needed to avoid unwanted sharing
-        self._exploration_schedulers = [
-            opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options
-        ]
+        self._exploration_schedulers = (
+            [opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
+            if exploration_scheduling_options is not None
+            else []
+        )

         self._call_cnt = 0
         self._warmup = warmup
@@ -219,7 +221,7 @@ def eval(self) -> None:
     def train(self) -> None:
         self._q_net.train()

-    def get_state(self) -> object:
+    def get_state(self) -> dict:
         return self._q_net.get_state()

     def set_state(self, policy_state: dict) -> None:
diff --git a/maro/rl/rl_component/rl_component_bundle.py b/maro/rl/rl_component/rl_component_bundle.py
index 4654a8fa4..f85fe286b 100644
--- a/maro/rl/rl_component/rl_component_bundle.py
+++ b/maro/rl/rl_component/rl_component_bundle.py
@@ -1,7 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from dataclasses import dataclass
 from typing import Any, Dict, List

 from maro.rl.policy import AbsPolicy, RLPolicy
@@ -9,7 +8,6 @@
 from maro.rl.training import AbsTrainer


-@dataclass
 class RLComponentBundle:
     """Bundle of all necessary components to run a RL job in MARO.

@@ -27,15 +25,20 @@ class RLComponentBundle:
         mapping will not be trained.
     """

-    env_sampler: AbsEnvSampler
-    agent2policy: Dict[Any, str]
-    policies: List[AbsPolicy]
-    trainers: List[AbsTrainer]
-    device_mapping: Dict[str, str] = None
-    policy_trainer_mapping: Dict[str, str] = None
+    def __init__(
+        self,
+        env_sampler: AbsEnvSampler,
+        agent2policy: Dict[Any, str],
+        policies: List[AbsPolicy],
+        trainers: List[AbsTrainer],
+        device_mapping: Dict[str, str] = None,
+        policy_trainer_mapping: Dict[str, str] = None,
+    ) -> None:
+        self.env_sampler = env_sampler
+        self.agent2policy = agent2policy
+        self.policies = policies
+        self.trainers = trainers

-    def __post_init__(self) -> None:
-        # Check missing policies
         policy_set = set([policy.name for policy in self.policies])
         not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set]
         if len(not_found) > 0:
@@ -51,14 +54,14 @@ def __post_init__(self) -> None:
         self.policies = kept_policies
         policy_set = set([policy.name for policy in self.policies])

-        if self.device_mapping is not None:
-            self.device_mapping = {k: v for k, v in self.device_mapping.items() if k in policy_set}
-        else:
-            self.device_mapping = {}
-
-        # Create default policy-trainer mapping if not provided
-        if self.policy_trainer_mapping is None:  # Default policy-trainer naming rule
-            self.policy_trainer_mapping = {policy_name: policy_name.split(".")[0] for policy_name in policy_set}
+        self.device_mapping = (
+            {k: v for k, v in device_mapping.items() if k in policy_set} if device_mapping is not None else {}
+        )
+        self.policy_trainer_mapping = (
+            policy_trainer_mapping
+            if policy_trainer_mapping is not None
+            else {policy_name: policy_name.split(".")[0] for policy_name in policy_set}
+        )

         # Check missing trainers
         self.policy_trainer_mapping = {
diff --git a/maro/rl/rollout/batch_env_sampler.py b/maro/rl/rollout/batch_env_sampler.py
index 27adb6eae..a3504a156 100644
--- a/maro/rl/rollout/batch_env_sampler.py
+++ b/maro/rl/rollout/batch_env_sampler.py
@@ -4,12 +4,13 @@
 import os
 import time
 from itertools import chain
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
 import zmq
 from zmq import Context, Poller

+from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT
 from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes
 from maro.rl.utils.objects import FILE_SUFFIX
 from maro.utils import DummyLogger, LoggerV2
@@ -37,19 +38,19 @@ def __init__(self, port: int = 20000, logger: LoggerV2 = None) -> None:
         self._poller = Poller()
         self._poller.register(self._task_endpoint, zmq.POLLIN)

-        self._workers = set()
-        self._logger = logger
+        self._workers: set = set()
+        self._logger: Union[DummyLogger, LoggerV2] = logger if logger is not None else DummyLogger()

     def _wait_for_workers_ready(self, k: int) -> None:
         while len(self._workers) < k:
             self._workers.add(self._task_endpoint.recv_multipart()[0])

-    def _recv_result_for_target_index(self, index: int) -> object:
+    def _recv_result_for_target_index(self, index: int) -> Any:
         rep = bytes_to_pyobj(self._task_endpoint.recv_multipart()[-1])
         assert isinstance(rep, dict)
         return rep["result"] if rep["index"] == index else None

-    def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: int = None) -> List[dict]:
+    def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: float = None) -> List[dict]:
         """Send a task request to a set of remote workers and collect the results.

         Args:
@@ -70,7 +71,7 @@ def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_fa
             min_replies = parallelism

         start_time = time.time()
-        results = []
+        results: list = []
         for worker_id in list(self._workers)[:parallelism]:
             self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)])
         self._logger.debug(f"Sent {parallelism} roll-out requests...")
@@ -81,7 +82,7 @@ def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_fa
                 results.append(result)

         if grace_factor is not None:
-            countdown = int((time.time() - start_time) * grace_factor) * 1000  # milliseconds
+            countdown = int((time.time() - start_time) * grace_factor) * 1000.0  # milliseconds
             self._logger.debug(f"allowing {countdown / 1000} seconds for remaining results")
             while len(results) < parallelism and countdown > 0:
                 start = time.time()
@@ -125,15 +126,18 @@ class BatchEnvSampler:
     def __init__(
         self,
         sampling_parallelism: int,
-        port: int = 20000,
+        port: int = None,
         min_env_samples: int = None,
         grace_factor: float = None,
         eval_parallelism: int = None,
         logger: LoggerV2 = None,
     ) -> None:
         super(BatchEnvSampler, self).__init__()
-        self._logger = logger if logger else DummyLogger()
-        self._controller = ParallelTaskController(port=port, logger=logger)
+        self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger()
+        self._controller = ParallelTaskController(
+            port=port if port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT,
+            logger=logger,
+        )

         self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism
         self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism
@@ -143,11 +147,15 @@ def __init__(
         self._ep = 0
         self._end_of_episode = True

-    def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Optional[int] = None) -> dict:
+    def sample(
+        self,
+        policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
+        num_steps: Optional[int] = None,
+    ) -> dict:
         """Collect experiences from a set of remote roll-out workers.

         Args:
-            policy_state (Dict[str, object]): Policy state dict. If it is not None, then we need to update all
+            policy_state (Dict[str, Any]): Policy state dict. If it is not None, then we need to update all
                 policies according to the latest policy states, then start the experience collection.
             num_steps (Optional[int], default=None): Number of environment steps to collect experiences for.
                If it is None, interactions with the (remote) environments will continue until the terminal state is
@@ -181,7 +189,7 @@ def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Op
             "info": [res["info"][0] for res in results],
         }

-    def eval(self, policy_state: Dict[str, object] = None) -> dict:
+    def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
         req = {"type": "eval", "policy_state": policy_state, "index": self._ep}  # -1 signals test
         results = self._controller.collect(req, self._eval_parallelism)
         return {
diff --git a/maro/rl/rollout/env_sampler.py b/maro/rl/rollout/env_sampler.py
index 9ce5f1cbc..81e6ccdad 100644
--- a/maro/rl/rollout/env_sampler.py
+++ b/maro/rl/rollout/env_sampler.py
@@ -47,16 +47,16 @@ def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None:
     def choose_actions(
         self,
-        state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
-    ) -> Dict[Any, Union[np.ndarray, List[object]]]:
+        state_by_agent: Dict[Any, Union[np.ndarray, list]],
+    ) -> Dict[Any, Union[np.ndarray, list]]:
         """Choose action according to the given (observable) states of all agents.

         Args:
-            state_by_agent (Dict[Any, Union[np.ndarray, List[object]]]): Dictionary containing each agent's states.
+            state_by_agent (Dict[Any, Union[np.ndarray, list]]): Dictionary containing each agent's states.
                 If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects.

         Returns:
-            actions (Dict[Any, Union[np.ndarray, List[object]]]): Dict that contains the action for all agents.
+            actions (Dict[Any, Union[np.ndarray, list]]): Dict that contains the action for all agents.
                 If the policy is a `RLPolicy`, its action is a Numpy array. Otherwise, its action is a list of objects.
         """
         self.switch_to_eval_mode()
@@ -67,8 +67,8 @@ def choose_actions(
     @abstractmethod
     def _choose_actions_impl(
         self,
-        state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
-    ) -> Dict[Any, Union[np.ndarray, List[object]]]:
+        state_by_agent: Dict[Any, Union[np.ndarray, list]],
+    ) -> Dict[Any, Union[np.ndarray, list]]:
         """Implementation of `choose_actions`."""
         raise NotImplementedError
@@ -91,15 +91,15 @@ def switch_to_eval_mode(self) -> None:
 class SimpleAgentWrapper(AbsAgentWrapper):
     def __init__(
         self,
-        policy_dict: Dict[str, RLPolicy],  # {policy_name: RLPolicy}
+        policy_dict: Dict[str, AbsPolicy],  # {policy_name: AbsPolicy}
         agent2policy: Dict[Any, str],  # {agent_name: policy_name}
     ) -> None:
         super(SimpleAgentWrapper, self).__init__(policy_dict=policy_dict, agent2policy=agent2policy)

     def _choose_actions_impl(
         self,
-        state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
-    ) -> Dict[Any, Union[np.ndarray, List[object]]]:
+        state_by_agent: Dict[Any, Union[np.ndarray, list]],
+    ) -> Dict[Any, Union[np.ndarray, list]]:
         # Aggregate states by policy
         states_by_policy = collections.defaultdict(list)  # {str: list of np.ndarray}
         agents_by_policy = collections.defaultdict(list)  # {str: list of str}
@@ -108,15 +108,15 @@ def _choose_actions_impl(
             states_by_policy[policy_name].append(state)
             agents_by_policy[policy_name].append(agent_name)

-        action_dict = {}
+        action_dict: dict = {}
         for policy_name in agents_by_policy:
             policy = self._policy_dict[policy_name]

             if isinstance(policy, RLPolicy):
                 states = np.vstack(states_by_policy[policy_name])  # np.ndarray
             else:
-                states = states_by_policy[policy_name]  # List[object]
-            actions = policy.get_actions(states)  # np.ndarray or List[object]
+                states = states_by_policy[policy_name]  # list
+            actions: Union[np.ndarray, list] = policy.get_actions(states)  # np.ndarray or list
             action_dict.update(zip(agents_by_policy[policy_name], actions))

         return action_dict
@@ -184,7 +184,7 @@ def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str,
             Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this
                 dict is the trainer name.
         """
-        ret = collections.defaultdict(
+        ret: Dict[str, ExpElement] = collections.defaultdict(
             lambda: ExpElement(
                 tick=self.tick,
                 state=self.state,
@@ -209,7 +209,7 @@ def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str,
 @dataclass
 class CacheElement(ExpElement):
-    event: object
+    event: Any
     env_action_dict: Dict[Any, np.ndarray]

     def make_exp_element(self) -> ExpElement:
@@ -257,7 +257,7 @@ def __init__(
         self._agent_wrapper_cls = agent_wrapper_cls

-        self._event = None
+        self._event: Optional[list] = None
         self._end_of_episode = True
         self._state: Optional[np.ndarray] = None
         self._agent_state_dict: Dict[Any, np.ndarray] = {}
@@ -266,7 +266,7 @@ def __init__(
         self._agent_last_index: Dict[Any, int] = {}  # Index of last occurrence of agent in self._trans_cache
         self._reward_eval_delay = reward_eval_delay

-        self._info = {}
+        self._info: dict = {}

         assert self._reward_eval_delay is None or self._reward_eval_delay >= 0
@@ -291,23 +291,31 @@ def __init__(
             [policy_name in self._rl_policy_dict for policy_name in self._trainable_policies],
         ), "All trainable policies must be RL policies!"

+    @property
+    def env(self) -> Env:
+        assert self._env is not None
+        return self._env
+
+    def _switch_env(self, env: Env) -> None:
+        self._env = env
+
     def assign_policy_to_device(self, policy_name: str, device: torch.device) -> None:
         self._rl_policy_dict[policy_name].to_device(device)

     def _get_global_and_agent_state(
         self,
-        event: object,
+        event: Any,
         tick: int = None,
-    ) -> Tuple[Optional[object], Dict[Any, Union[np.ndarray, List[object]]]]:
+    ) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]:
         """Get the global and individual agents' states.

         Args:
-            event (object): Event.
+            event (Any): Event.
             tick (int, default=None): Current tick.

         Returns:
-            Global state (Optional[object])
-            Dict of agent states (Dict[Any, Union[np.ndarray, List[object]]]). If the policy is a `RLPolicy`,
+            Global state (Optional[Any])
+            Dict of agent states (Dict[Any, Union[np.ndarray, list]]). If the policy is a `RLPolicy`,
                 its state is a Numpy array. Otherwise, its state is a list of objects.
         """
         global_state, agent_state_dict = self._get_global_and_agent_state_impl(event, tick)
@@ -321,23 +329,23 @@ def _get_global_and_agent_state(
     @abstractmethod
     def _get_global_and_agent_state_impl(
         self,
-        event: object,
+        event: Any,
         tick: int = None,
-    ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
+    ) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
         raise NotImplementedError

     @abstractmethod
     def _translate_to_env_action(
         self,
-        action_dict: Dict[Any, Union[np.ndarray, List[object]]],
-        event: object,
-    ) -> Dict[Any, object]:
+        action_dict: Dict[Any, Union[np.ndarray, list]],
+        event: Any,
+    ) -> dict:
         """Translate model-generated actions into an object that can be executed by the env.

         Args:
-            action_dict (Dict[Any, Union[np.ndarray, List[object]]]): Action for all agents. If the policy is a
+            action_dict (Dict[Any, Union[np.ndarray, list]]): Action for all agents. If the policy is a
                 `RLPolicy`, its (input) action is a Numpy array. Otherwise, its (input) action is a list of objects.
-            event (object): Decision event.
+            event (Any): Decision event.

         Returns:
             A dict that contains env actions for all agents.
@@ -345,12 +353,12 @@ def _translate_to_env_action(
         raise NotImplementedError

     @abstractmethod
-    def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: int) -> Dict[Any, float]:
+    def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
         """Get rewards according to the env actions.

         Args:
-            env_action_dict (Dict[Any, object]): Dict that contains env actions for all agents.
-            event (object): Decision event.
+            env_action_dict (dict): Dict that contains env actions for all agents.
+            event (Any): Decision event.
             tick (int): Current tick.

         Returns:
@@ -359,7 +367,7 @@ def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: i
         raise NotImplementedError

     def _step(self, actions: Optional[list]) -> None:
-        _, self._event, self._end_of_episode = self._env.step(actions)
+        _, self._event, self._end_of_episode = self.env.step(actions)
         self._state, self._agent_state_dict = (
             (None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event)
         )
@@ -397,7 +405,7 @@ def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None:
                 self._agent_last_index[agent_name] = cur_index

     def _reset(self) -> None:
-        self._env.reset()
+        self.env.reset()
         self._info.clear()
         self._trans_cache.clear()
         self._agent_last_index.clear()
@@ -406,7 +414,11 @@ def _reset(self) -> None:
     def _select_trainable_agents(self, original_dict: dict) -> dict:
         return {k: v for k, v in original_dict.items() if k in self._trainable_agents}

-    def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Optional[int] = None) -> dict:
+    def sample(
+        self,
+        policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
+        num_steps: Optional[int] = None,
+    ) -> dict:
         """Sample experiences.

         Args:
@@ -419,7 +431,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti
             A dict that contains the collected experiences and additional information.
""" # Init the env - self._env = self._learn_env + self._switch_env(self._learn_env) if self._end_of_episode: self._reset() @@ -437,7 +449,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti # Store experiences in the cache cache_element = CacheElement( - tick=self._env.tick, + tick=self.env.tick, event=self._event, state=self._state, agent_state_dict=self._select_trainable_agents(self._agent_state_dict), @@ -460,7 +472,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti steps_to_go -= 1 self._append_cache_element(None) - tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) experiences: List[ExpElement] = [] while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: cache_element = self._trans_cache.pop(0) @@ -502,8 +514,8 @@ def load_policy_state(self, path: str) -> List[str]: return loaded - def eval(self, policy_state: Dict[str, dict] = None) -> dict: - self._env = self._test_env + def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict: + self._switch_env(self._test_env) self._reset() if policy_state is not None: self.set_policy_state(policy_state) @@ -515,7 +527,7 @@ def eval(self, policy_state: Dict[str, dict] = None) -> dict: # Store experiences in the cache cache_element = CacheElement( - tick=self._env.tick, + tick=self.env.tick, event=self._event, state=self._state, agent_state_dict=self._select_trainable_agents(self._agent_state_dict), @@ -538,7 +550,7 @@ def eval(self, policy_state: Dict[str, dict] = None) -> dict: self._append_cache_element(cache_element) self._append_cache_element(None) - tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: cache_element = self._trans_cache.pop(0) if self._reward_eval_delay is not None: diff --git a/maro/rl/rollout/worker.py b/maro/rl/rollout/worker.py index ec4089195..b8301ee38 100644 --- a/maro/rl/rollout/worker.py +++ b/maro/rl/rollout/worker.py @@ -5,7 +5,7 @@ import typing -from maro.rl.distributed import AbsWorker +from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT, AbsWorker from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes from maro.utils import LoggerV2 @@ -30,13 +30,13 @@ def __init__( idx: int, rl_component_bundle: RLComponentBundle, producer_host: str, - producer_port: int = 20000, + producer_port: int = None, logger: LoggerV2 = None, ) -> None: super(RolloutWorker, self).__init__( idx=idx, producer_host=producer_host, - producer_port=producer_port, + producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, logger=logger, ) self._env_sampler = rl_component_bundle.env_sampler @@ -54,18 +54,19 @@ def _compute(self, msg: list) -> None: req = bytes_to_pyobj(msg[-1]) assert isinstance(req, dict) assert req["type"] in {"sample", "eval", "set_policy_state", "post_collect", "post_evaluate"} - if req["type"] == "sample": - result = self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"]) - elif req["type"] == "eval": - result = self._env_sampler.eval(policy_state=req["policy_state"]) - elif req["type"] == "set_policy_state": - 
-            result = True
-        elif req["type"] == "post_collect":
-            self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"])
-            result = True
-        else:
-            self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"])
-            result = True
-        self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
+        if req["type"] in ("sample", "eval"):
+            result = (
+                self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
+                if req["type"] == "sample"
+                else self._env_sampler.eval(policy_state=req["policy_state"])
+            )
+            self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
+        else:
+            if req["type"] == "set_policy_state":
+                self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"])
+            elif req["type"] == "post_collect":
+                self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"])
+            else:
+                self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"])
+            self._stream.send(pyobj_to_bytes({"result": True, "index": req["index"]}))
diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py
index 3f2d01a4c..a77296f98 100644
--- a/maro/rl/training/__init__.py
+++ b/maro/rl/training/__init__.py
@@ -4,7 +4,7 @@
 from .proxy import TrainingProxy
 from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
 from .train_ops import AbsTrainOps, RemoteOps, remote
-from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams
+from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer
 from .training_manager import TrainingManager
 from .worker import TrainOpsWorker
@@ -18,9 +18,9 @@
     "RemoteOps",
     "remote",
     "AbsTrainer",
+    "BaseTrainerParams",
     "MultiAgentTrainer",
     "SingleAgentTrainer",
-    "TrainerParams",
     "TrainingManager",
     "TrainOpsWorker",
 ]
diff --git a/maro/rl/training/algorithms/ac.py b/maro/rl/training/algorithms/ac.py
index 2f9d576e2..4486daee3 100644
--- a/maro/rl/training/algorithms/ac.py
+++ b/maro/rl/training/algorithms/ac.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT license.

 from dataclasses import dataclass
-from typing import Dict

 from maro.rl.training.algorithms.base import ACBasedParams, ACBasedTrainer
@@ -13,18 +12,8 @@ class ActorCriticParams(ACBasedParams):
     for detailed information.
""" - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_v_critic_net_func": self.get_v_critic_net_func, - "reward_discount": self.reward_discount, - "critic_loss_cls": self.critic_loss_cls, - "lam": self.lam, - "min_logp": self.min_logp, - "is_discrete_action": self.is_discrete_action, - } - def __post_init__(self) -> None: - assert self.get_v_critic_net_func is not None + assert self.clip_ratio is None class ActorCriticTrainer(ACBasedTrainer): @@ -34,5 +23,20 @@ class ActorCriticTrainer(ACBasedTrainer): https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/vpg """ - def __init__(self, name: str, params: ActorCriticParams) -> None: - super(ActorCriticTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: ActorCriticParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(ActorCriticTrainer, self).__init__( + name, + params, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py index e97b0f6e3..3227437be 100644 --- a/maro/rl/training/algorithms/base/ac_ppo_base.py +++ b/maro/rl/training/algorithms/base/ac_ppo_base.py @@ -3,19 +3,19 @@ from abc import ABCMeta from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, cast import numpy as np import torch from maro.rl.model import VNet from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy -from maro.rl.training import AbsTrainOps, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, discount_cumsum, get_torch_device, ndarray_to_tensor @dataclass -class ACBasedParams(TrainerParams, metaclass=ABCMeta): +class ACBasedParams(BaseTrainerParams, metaclass=ABCMeta): """ Parameter bundle for Actor-Critic based algorithms (Actor-Critic & PPO) @@ -23,18 +23,16 @@ class ACBasedParams(TrainerParams, metaclass=ABCMeta): grad_iters (int, default=1): Number of iterations to calculate gradients. critic_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE. lam (float, default=0.9): Lambda value for generalized advantage estimation (TD-Lambda). - min_logp (float, default=None): Lower bound for clamping logP values during learning. + min_logp (float, default=float("-inf")): Lower bound for clamping logP values during learning. This is to prevent logP from becoming very large in magnitude and causing stability issues. - If it is None, it means no lower bound. - is_discrete_action (bool, default=True): Indicator of continuous or discrete action policy. 
""" - get_v_critic_net_func: Callable[[], VNet] = None + get_v_critic_net_func: Callable[[], VNet] grad_iters: int = 1 - critic_loss_cls: Callable = None + critic_loss_cls: Optional[Callable] = None lam: float = 0.9 - min_logp: Optional[float] = None - is_discrete_action: bool = True + min_logp: float = float("-inf") + clip_ratio: Optional[float] = None class ACBasedOps(AbsTrainOps): @@ -44,14 +42,9 @@ def __init__( self, name: str, policy: RLPolicy, - get_v_critic_net_func: Callable[[], VNet], - parallelism: int = 1, + params: ACBasedParams, reward_discount: float = 0.9, - critic_loss_cls: Callable = None, - clip_ratio: float = None, - lam: float = 0.9, - min_logp: float = None, - is_discrete_action: bool = True, + parallelism: int = 1, ) -> None: super(ACBasedOps, self).__init__( name=name, @@ -59,15 +52,15 @@ def __init__( parallelism=parallelism, ) - assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy) + assert isinstance(self._policy, (ContinuousRLPolicy, DiscretePolicyGradient)) self._reward_discount = reward_discount - self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() - self._clip_ratio = clip_ratio - self._lam = lam - self._min_logp = min_logp - self._v_critic_net = get_v_critic_net_func() - self._is_discrete_action = is_discrete_action + self._critic_loss_func = params.critic_loss_cls() if params.critic_loss_cls is not None else torch.nn.MSELoss() + self._clip_ratio = params.clip_ratio + self._lam = params.lam + self._min_logp = params.min_logp + self._v_critic_net = params.get_v_critic_net_func() + self._is_discrete_action = isinstance(self._policy, DiscretePolicyGradient) def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: """Compute the critic loss of the batch. @@ -247,14 +240,32 @@ class ACBasedTrainer(SingleAgentTrainer): https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f """ - def __init__(self, name: str, params: ACBasedParams) -> None: - super(ACBasedTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: ACBasedParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(ACBasedTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, (ContinuousRLPolicy, DiscretePolicyGradient)) + self._policy = policy + def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(ACBasedOps, self.get_ops()) self._replay_memory = FIFOReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, ) @@ -266,8 +277,9 @@ def get_local_ops(self) -> AbsTrainOps: return ACBasedOps( name=self._policy.name, policy=self._policy, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def _get_batch(self) -> TransitionBatch: diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index f7a5da1eb..79bd5b336 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -2,19 +2,19 @@ # Licensed under the MIT license. 
 from dataclasses import dataclass
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional, cast

 import torch

 from maro.rl.model import QNet
 from maro.rl.policy import ContinuousRLPolicy, RLPolicy
-from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
+from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
 from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
 from maro.utils import clone


 @dataclass
-class DDPGParams(TrainerParams):
+class DDPGParams(BaseTrainerParams):
     """
     get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net.
     num_epochs (int, default=1): Number of training epochs per call to ``learn``.
@@ -30,25 +30,14 @@ class DDPGParams(TrainerParams):
     min_num_to_trigger_training (int, default=0): Minimum number required to start training.
     """

-    get_q_critic_net_func: Callable[[], QNet] = None
+    get_q_critic_net_func: Callable[[], QNet]
     num_epochs: int = 1
     update_target_every: int = 5
-    q_value_loss_cls: Callable = None
+    q_value_loss_cls: Optional[Callable] = None
     soft_update_coef: float = 1.0
     random_overwrite: bool = False
     min_num_to_trigger_training: int = 0

-    def __post_init__(self) -> None:
-        assert self.get_q_critic_net_func is not None
-
-    def extract_ops_params(self) -> Dict[str, object]:
-        return {
-            "get_q_critic_net_func": self.get_q_critic_net_func,
-            "reward_discount": self.reward_discount,
-            "q_value_loss_cls": self.q_value_loss_cls,
-            "soft_update_coef": self.soft_update_coef,
-        }
-

 class DDPGOps(AbsTrainOps):
     """DDPG algorithm implementation. Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html"""
@@ -57,11 +46,9 @@ def __init__(
         self,
         name: str,
         policy: RLPolicy,
-        get_q_critic_net_func: Callable[[], QNet],
-        reward_discount: float,
+        params: DDPGParams,
+        reward_discount: float = 0.9,
         parallelism: int = 1,
-        q_value_loss_cls: Callable = None,
-        soft_update_coef: float = 1.0,
     ) -> None:
         super(DDPGOps, self).__init__(
             name=name,
@@ -71,16 +58,18 @@ def __init__(

         assert isinstance(self._policy, ContinuousRLPolicy)

-        self._target_policy = clone(self._policy)
+        self._target_policy: ContinuousRLPolicy = clone(self._policy)
         self._target_policy.set_name(f"target_{self._policy.name}")
         self._target_policy.eval()
-        self._q_critic_net = get_q_critic_net_func()
+        self._q_critic_net = params.get_q_critic_net_func()
         self._target_q_critic_net: QNet = clone(self._q_critic_net)
         self._target_q_critic_net.eval()

         self._reward_discount = reward_discount
-        self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss()
-        self._soft_update_coef = soft_update_coef
+        self._q_value_loss_func = (
+            params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss()
+        )
+        self._soft_update_coef = params.soft_update_coef

     def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         """Compute the critic loss of the batch.
@@ -207,7 +196,7 @@ def soft_update_target(self) -> None:
         self._target_policy.soft_update(self._policy, self._soft_update_coef)
         self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)

-    def to_device(self, device: str) -> None:
+    def to_device(self, device: str = None) -> None:
         self._device = get_torch_device(device=device)
         self._policy.to_device(self._device)
         self._target_policy.to_device(self._device)
@@ -223,21 +212,39 @@ class DDPGTrainer(SingleAgentTrainer):
     https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
     """

-    def __init__(self, name: str, params: DDPGParams) -> None:
-        super(DDPGTrainer, self).__init__(name, params)
+    def __init__(
+        self,
+        name: str,
+        params: DDPGParams,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
+        super(DDPGTrainer, self).__init__(
+            name,
+            replay_memory_capacity,
+            batch_size,
+            data_parallelism,
+            reward_discount,
+        )
         self._params = params
         self._policy_version = self._target_policy_version = 0
         self._memory_size = 0

     def build(self) -> None:
-        self._ops = self.get_ops()
+        self._ops = cast(DDPGOps, self.get_ops())
         self._replay_memory = RandomReplayMemory(
-            capacity=self._params.replay_memory_capacity,
+            capacity=self._replay_memory_capacity,
             state_dim=self._ops.policy_state_dim,
             action_dim=self._ops.policy_action_dim,
             random_overwrite=self._params.random_overwrite,
         )

+    def _register_policy(self, policy: RLPolicy) -> None:
+        assert isinstance(policy, ContinuousRLPolicy)
+        self._policy = policy
+
     def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
         return transition_batch
@@ -245,8 +252,9 @@ def get_local_ops(self) -> AbsTrainOps:
         return DDPGOps(
             name=self._policy.name,
             policy=self._policy,
-            parallelism=self._params.data_parallelism,
-            **self._params.extract_ops_params(),
+            parallelism=self._data_parallelism,
+            reward_discount=self._reward_discount,
+            params=self._params,
         )

     def _get_batch(self, batch_size: int = None) -> TransitionBatch:
diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py
index 5ad2bf5e8..5a4f938ab 100644
--- a/maro/rl/training/algorithms/dqn.py
+++ b/maro/rl/training/algorithms/dqn.py
@@ -2,18 +2,18 @@
 # Licensed under the MIT license.

 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, cast

 import torch

 from maro.rl.policy import RLPolicy, ValueBasedPolicy
-from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
+from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
 from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
 from maro.utils import clone


 @dataclass
-class DQNParams(TrainerParams):
+class DQNParams(BaseTrainerParams):
     """
     num_epochs (int, default=1): Number of training epochs.
     update_target_every (int, default=5): Number of gradient steps between target model updates.
@@ -33,23 +33,15 @@ class DQNParams(TrainerParams):
     double: bool = False
     random_overwrite: bool = False

-    def extract_ops_params(self) -> Dict[str, object]:
-        return {
-            "reward_discount": self.reward_discount,
-            "soft_update_coef": self.soft_update_coef,
-            "double": self.double,
-        }
-

 class DQNOps(AbsTrainOps):
     def __init__(
         self,
         name: str,
         policy: RLPolicy,
-        parallelism: int = 1,
+        params: DQNParams,
         reward_discount: float = 0.9,
-        soft_update_coef: float = 0.1,
-        double: bool = False,
+        parallelism: int = 1,
     ) -> None:
         super(DQNOps, self).__init__(
             name=name,
@@ -60,8 +52,8 @@ def __init__(
         assert isinstance(self._policy, ValueBasedPolicy)

         self._reward_discount = reward_discount
-        self._soft_update_coef = soft_update_coef
-        self._double = double
+        self._soft_update_coef = params.soft_update_coef
+        self._double = params.double
         self._loss_func = torch.nn.MSELoss()

         self._target_policy: ValueBasedPolicy = clone(self._policy)
@@ -143,7 +135,7 @@ def soft_update_target(self) -> None:
         """Soft update the target policy."""
         self._target_policy.soft_update(self._policy, self._soft_update_coef)

-    def to_device(self, device: str) -> None:
+    def to_device(self, device: str = None) -> None:
         self._device = get_torch_device(device)
         self._policy.to_device(self._device)
         self._target_policy.to_device(self._device)
@@ -155,20 +147,38 @@ class DQNTrainer(SingleAgentTrainer):
     See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details.
     """

-    def __init__(self, name: str, params: DQNParams) -> None:
-        super(DQNTrainer, self).__init__(name, params)
+    def __init__(
+        self,
+        name: str,
+        params: DQNParams,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
+        super(DQNTrainer, self).__init__(
+            name,
+            replay_memory_capacity,
+            batch_size,
+            data_parallelism,
+            reward_discount,
+        )
         self._params = params
         self._q_net_version = self._target_q_net_version = 0

     def build(self) -> None:
-        self._ops = self.get_ops()
+        self._ops = cast(DQNOps, self.get_ops())
         self._replay_memory = RandomReplayMemory(
-            capacity=self._params.replay_memory_capacity,
+            capacity=self._replay_memory_capacity,
             state_dim=self._ops.policy_state_dim,
             action_dim=self._ops.policy_action_dim,
             random_overwrite=self._params.random_overwrite,
         )

+    def _register_policy(self, policy: RLPolicy) -> None:
+        assert isinstance(policy, ValueBasedPolicy)
+        self._policy = policy
+
     def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
         return transition_batch
@@ -176,8 +186,9 @@ def get_local_ops(self) -> AbsTrainOps:
         return DQNOps(
             name=self._policy.name,
             policy=self._policy,
-            parallelism=self._params.data_parallelism,
-            **self._params.extract_ops_params(),
+            parallelism=self._data_parallelism,
+            reward_discount=self._reward_discount,
+            params=self._params,
         )

     def _get_batch(self, batch_size: int = None) -> TransitionBatch:
diff --git a/maro/rl/training/algorithms/maddpg.py b/maro/rl/training/algorithms/maddpg.py
index bddc6eb70..edc63f39a 100644
--- a/maro/rl/training/algorithms/maddpg.py
+++ b/maro/rl/training/algorithms/maddpg.py
@@ -4,7 +4,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, cast

 import numpy as np
 import torch
@@ -12,14 +12,21 @@
 from maro.rl.model import MultiQNet
 from maro.rl.policy import DiscretePolicyGradient, RLPolicy
 from maro.rl.rollout import ExpElement
-from maro.rl.training import AbsTrainOps, MultiAgentTrainer, RandomMultiReplayMemory, RemoteOps, TrainerParams, remote
+from maro.rl.training import (
+    AbsTrainOps,
+    BaseTrainerParams,
+    MultiAgentTrainer,
+    RandomMultiReplayMemory,
+    RemoteOps,
+    remote,
+)
 from maro.rl.utils import MultiTransitionBatch, get_torch_device, ndarray_to_tensor
 from maro.rl.utils.objects import FILE_SUFFIX
 from maro.utils import clone


 @dataclass
-class DiscreteMADDPGParams(TrainerParams):
+class DiscreteMADDPGParams(BaseTrainerParams):
     """
     get_q_critic_net_func (Callable[[], MultiQNet]): Function to get multi Q critic net.
     num_epochs (int, default=10): Number of training epochs.
@@ -30,40 +37,24 @@ class DiscreteMADDPGParams(TrainerParams):
     shared_critic (bool, default=False): Whether different policies use shared critic or individual policies.
     """

-    get_q_critic_net_func: Callable[[], MultiQNet] = None
+    get_q_critic_net_func: Callable[[], MultiQNet]
     num_epoch: int = 10
     update_target_every: int = 5
     soft_update_coef: float = 0.5
-    q_value_loss_cls: Callable = None
+    q_value_loss_cls: Optional[Callable] = None
     shared_critic: bool = False

-    def __post_init__(self) -> None:
-        assert self.get_q_critic_net_func is not None
-
-    def extract_ops_params(self) -> Dict[str, object]:
-        return {
-            "get_q_critic_net_func": self.get_q_critic_net_func,
-            "shared_critic": self.shared_critic,
-            "reward_discount": self.reward_discount,
-            "soft_update_coef": self.soft_update_coef,
-            "update_target_every": self.update_target_every,
-            "q_value_loss_func": self.q_value_loss_cls() if self.q_value_loss_cls is not None else torch.nn.MSELoss(),
-        }
-

 class DiscreteMADDPGOps(AbsTrainOps):
     def __init__(
         self,
         name: str,
         policy: RLPolicy,
-        get_q_critic_net_func: Callable[[], MultiQNet],
+        param: DiscreteMADDPGParams,
+        shared_critic: bool,
         policy_idx: int,
         parallelism: int = 1,
-        shared_critic: bool = False,
         reward_discount: float = 0.9,
-        soft_update_coef: float = 0.5,
-        update_target_every: int = 5,
-        q_value_loss_func: Callable = None,
     ) -> None:
         super(DiscreteMADDPGOps, self).__init__(
             name=name,
@@ -82,14 +73,14 @@ def __init__(
             self._target_policy.eval()

         # Critic
-        self._q_critic_net: MultiQNet = get_q_critic_net_func()
+        self._q_critic_net: MultiQNet = param.get_q_critic_net_func()
         self._target_q_critic_net: MultiQNet = clone(self._q_critic_net)
         self._target_q_critic_net.eval()

         self._reward_discount = reward_discount
-        self._q_value_loss_func = q_value_loss_func
-        self._update_target_every = update_target_every
-        self._soft_update_coef = soft_update_coef
+        self._q_value_loss_func = param.q_value_loss_cls() if param.q_value_loss_cls is not None else torch.nn.MSELoss()
+        self._update_target_every = param.update_target_every
+        self._soft_update_coef = param.soft_update_coef

     def get_target_action(self, batch: MultiTransitionBatch) -> torch.Tensor:
         """Get the target policies' actions according to the batch.
@@ -278,7 +269,7 @@ def get_non_policy_state(self) -> dict:
     def set_non_policy_state(self, state: dict) -> None:
         self.set_critic_state(state)

-    def to_device(self, device: str) -> None:
+    def to_device(self, device: str = None) -> None:
         self._device = get_torch_device(device)
         if self._policy:
             self._policy.to_device(self._device)
@@ -294,31 +285,51 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
     See https://arxiv.org/abs/1706.02275 for details.
""" - def __init__(self, name: str, params: DiscreteMADDPGParams) -> None: - super(DiscreteMADDPGTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: DiscreteMADDPGParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(DiscreteMADDPGTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params - self._ops_params = self._params.extract_ops_params() + self._state_dim = params.get_q_critic_net_func().state_dim self._policy_version = self._target_policy_version = 0 self._shared_critic_ops_name = f"{self._name}.shared_critic" - self._actor_ops_list = [] - self._critic_ops = None - self._replay_memory = None - self._policy2agent = {} + self._actor_ops_list: List[DiscreteMADDPGOps] = [] + self._critic_ops: Optional[DiscreteMADDPGOps] = None + self._policy2agent: Dict[str, str] = {} + self._ops_dict: Dict[str, DiscreteMADDPGOps] = {} def build(self) -> None: + self._placeholder_policy = self._policy_dict[self._policy_names[0]] + for policy in self._policy_dict.values(): - self._ops_dict[policy.name] = self.get_ops(policy.name) + self._ops_dict[policy.name] = cast(DiscreteMADDPGOps, self.get_ops(policy.name)) self._actor_ops_list = list(self._ops_dict.values()) if self._params.shared_critic: - self._ops_dict[self._shared_critic_ops_name] = self.get_ops(self._shared_critic_ops_name) + assert self._critic_ops is not None + self._ops_dict[self._shared_critic_ops_name] = cast( + DiscreteMADDPGOps, + self.get_ops(self._shared_critic_ops_name), + ) self._critic_ops = self._ops_dict[self._shared_critic_ops_name] self._replay_memory = RandomMultiReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._state_dim, action_dims=[ops.policy_action_dim for ops in self._actor_ops_list], agent_states_dims=[ops.policy_state_dim for ops in self._actor_ops_list], @@ -372,23 +383,25 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: def get_local_ops(self, name: str) -> AbsTrainOps: if name == self._shared_critic_ops_name: - ops_params = dict(self._ops_params) - ops_params.update( - { - "policy_idx": -1, - "shared_critic": False, - }, + return DiscreteMADDPGOps( + name=name, + policy=self._placeholder_policy, + param=self._params, + shared_critic=False, + policy_idx=-1, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, ) - return DiscreteMADDPGOps(name=name, **ops_params) else: - ops_params = dict(self._ops_params) - ops_params.update( - { - "policy": self._policy_dict[name], - "policy_idx": self._policy_names.index(name), - }, + return DiscreteMADDPGOps( + name=name, + policy=self._policy_dict[name], + param=self._params, + shared_critic=self._params.shared_critic, + policy_idx=self._policy_names.index(name), + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, ) - return DiscreteMADDPGOps(name=name, **ops_params) def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) @@ -403,6 +416,7 @@ def train_step(self) -> None: # Update critic if self._params.shared_critic: + assert self._critic_ops is not None self._critic_ops.update_critic(batch, next_actions) critic_state_dict = self._critic_ops.get_critic_state() # Sync latest critic to ops @@ -429,6 +443,7 @@ async 
def train_step_as_task(self) -> None: # Update critic if self._params.shared_critic: + assert self._critic_ops is not None critic_grad = await asyncio.gather(*[self._critic_ops.get_critic_grad(batch, next_actions)]) assert isinstance(critic_grad, list) and isinstance(critic_grad[0], dict) self._critic_ops.update_critic_with_grad(critic_grad[0]) @@ -458,10 +473,11 @@ def _try_soft_update_target(self) -> None: for ops in self._actor_ops_list: ops.soft_update_target() if self._params.shared_critic: + assert self._critic_ops is not None self._critic_ops.soft_update_target() self._target_policy_version = self._policy_version - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: self._assert_ops_exists() ret_policy_state = {} for ops in self._actor_ops_list: @@ -482,6 +498,7 @@ def save(self, path: str) -> None: trainer_state = {ops.name: ops.get_state() for ops in self._actor_ops_list} if self._params.shared_critic: + assert self._critic_ops is not None trainer_state[self._critic_ops.name] = self._critic_ops.get_state() policy_state_dict = {ops_name: state["policy"] for ops_name, state in trainer_state.items()} diff --git a/maro/rl/training/algorithms/ppo.py b/maro/rl/training/algorithms/ppo.py index bbbdc3adc..7abe089ef 100644 --- a/maro/rl/training/algorithms/ppo.py +++ b/maro/rl/training/algorithms/ppo.py @@ -2,13 +2,12 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Callable, Dict, Tuple +from typing import Tuple import numpy as np import torch from torch.distributions import Categorical -from maro.rl.model import VNet from maro.rl.policy import DiscretePolicyGradient, RLPolicy from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer from maro.rl.utils import TransitionBatch, discount_cumsum, ndarray_to_tensor @@ -24,21 +23,7 @@ class PPOParams(ACBasedParams): If it is None, the actor loss is calculated using the usual policy gradient theorem. 
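The ``clip_ratio`` field documented above drives PPO's clipped surrogate objective. As a rough illustration only (not the trainer's own implementation; the function name and the default value are assumptions), the clipped actor loss can be sketched as:

import torch


def clipped_actor_loss(
    logps: torch.Tensor,
    logps_old: torch.Tensor,
    advantages: torch.Tensor,
    clip_ratio: float = 0.1,
) -> torch.Tensor:
    # Probability ratio between the current policy and the behaviour policy.
    ratio = torch.exp(logps - logps_old)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Pessimistic (elementwise minimum) objective, negated because optimizers minimize.
    return -torch.min(ratio * advantages, clipped_ratio * advantages).mean()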
""" - clip_ratio: float = None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_v_critic_net_func": self.get_v_critic_net_func, - "reward_discount": self.reward_discount, - "critic_loss_cls": self.critic_loss_cls, - "clip_ratio": self.clip_ratio, - "lam": self.lam, - "min_logp": self.min_logp, - "is_discrete_action": self.is_discrete_action, - } - def __post_init__(self) -> None: - assert self.get_v_critic_net_func is not None assert self.clip_ratio is not None @@ -47,30 +32,19 @@ def __init__( self, name: str, policy: RLPolicy, - get_v_critic_net_func: Callable[[], VNet], + params: ACBasedParams, parallelism: int = 1, reward_discount: float = 0.9, - critic_loss_cls: Callable = None, - clip_ratio: float = None, - lam: float = 0.9, - min_logp: float = None, - is_discrete_action: bool = True, ) -> None: super(DiscretePPOWithEntropyOps, self).__init__( - name=name, - policy=policy, - get_v_critic_net_func=get_v_critic_net_func, - parallelism=parallelism, - reward_discount=reward_discount, - critic_loss_cls=critic_loss_cls, - clip_ratio=clip_ratio, - lam=lam, - min_logp=min_logp, - is_discrete_action=is_discrete_action, + name, + policy, + params, + reward_discount, + parallelism, ) - assert is_discrete_action - assert isinstance(self._policy, DiscretePolicyGradient) - self._policy_old = clone(policy) + assert self._is_discrete_action + self._policy_old: DiscretePolicyGradient = clone(policy) self.update_policy_old() def update_policy_old(self) -> None: @@ -173,8 +147,23 @@ class PPOTrainer(ACBasedTrainer): https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo. """ - def __init__(self, name: str, params: PPOParams) -> None: - super(PPOTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: PPOParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(PPOTrainer, self).__init__( + name, + params, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) class DiscretePPOWithEntropyTrainer(ACBasedTrainer): @@ -185,8 +174,9 @@ def get_local_ops(self) -> DiscretePPOWithEntropyOps: return DiscretePPOWithEntropyOps( name=self._policy.name, policy=self._policy, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def train_step(self) -> None: diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index ce2eee6d2..338addf57 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -2,53 +2,37 @@ # Licensed under the MIT license. 
from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, cast import torch from maro.rl.model import QNet from maro.rl.policy import ContinuousRLPolicy, RLPolicy -from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @dataclass -class SoftActorCriticParams(TrainerParams): - get_q_critic_net_func: Callable[[], QNet] = None +class SoftActorCriticParams(BaseTrainerParams): + get_q_critic_net_func: Callable[[], QNet] update_target_every: int = 5 random_overwrite: bool = False entropy_coef: float = 0.1 num_epochs: int = 1 n_start_train: int = 0 - q_value_loss_cls: Callable = None + q_value_loss_cls: Optional[Callable] = None soft_update_coef: float = 1.0 - def __post_init__(self) -> None: - assert self.get_q_critic_net_func is not None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_q_critic_net_func": self.get_q_critic_net_func, - "entropy_coef": self.entropy_coef, - "reward_discount": self.reward_discount, - "q_value_loss_cls": self.q_value_loss_cls, - "soft_update_coef": self.soft_update_coef, - } - class SoftActorCriticOps(AbsTrainOps): def __init__( self, name: str, policy: RLPolicy, - get_q_critic_net_func: Callable[[], QNet], + params: SoftActorCriticParams, + reward_discount: float = 0.9, parallelism: int = 1, - *, - entropy_coef: float, - reward_discount: float, - q_value_loss_cls: Callable = None, - soft_update_coef: float = 1.0, ) -> None: super(SoftActorCriticOps, self).__init__( name=name, @@ -58,17 +42,19 @@ def __init__( assert isinstance(self._policy, ContinuousRLPolicy) - self._q_net1 = get_q_critic_net_func() - self._q_net2 = get_q_critic_net_func() + self._q_net1 = params.get_q_critic_net_func() + self._q_net2 = params.get_q_critic_net_func() self._target_q_net1: QNet = clone(self._q_net1) self._target_q_net1.eval() self._target_q_net2: QNet = clone(self._q_net2) self._target_q_net2.eval() - self._entropy_coef = entropy_coef - self._soft_update_coef = soft_update_coef + self._entropy_coef = params.entropy_coef + self._soft_update_coef = params.soft_update_coef self._reward_discount = reward_discount - self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() + self._q_value_loss_func = ( + params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss() + ) def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]: self._q_net1.train() @@ -100,11 +86,11 @@ def get_critic_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tenso grad_q2 = self._q_net2.get_gradients(loss_q2) return grad_q1, grad_q2 - def update_critic_with_grad(self, grad_dict1: dict, grad_dict2: dict) -> None: + def update_critic_with_grad(self, grad_dicts: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: self._q_net1.train() self._q_net2.train() - self._q_net1.apply_gradients(grad_dict1) - self._q_net2.apply_gradients(grad_dict2) + self._q_net1.apply_gradients(grad_dicts[0]) + self._q_net2.apply_gradients(grad_dicts[1]) def update_critic(self, batch: TransitionBatch) -> None: self._q_net1.train() @@ -154,7 +140,7 @@ def soft_update_target(self) -> None: self._target_q_net1.soft_update(self._q_net1, 
self._soft_update_coef) self._target_q_net2.soft_update(self._q_net2, self._soft_update_coef) - def to_device(self, device: str) -> None: + def to_device(self, device: str = None) -> None: self._device = get_torch_device(device=device) self._q_net1.to(self._device) self._q_net2.to(self._device) @@ -163,22 +149,38 @@ def to_device(self, device: str) -> None: class SoftActorCriticTrainer(SingleAgentTrainer): - def __init__(self, name: str, params: SoftActorCriticParams) -> None: - super(SoftActorCriticTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: SoftActorCriticParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(SoftActorCriticTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params self._qnet_version = self._target_qnet_version = 0 - self._replay_memory: Optional[RandomReplayMemory] = None - def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(SoftActorCriticOps, self.get_ops()) self._replay_memory = RandomReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, random_overwrite=self._params.random_overwrite, ) + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, ContinuousRLPolicy) + self._policy = policy + def train_step(self) -> None: assert isinstance(self._ops, SoftActorCriticOps) @@ -220,8 +222,9 @@ def get_local_ops(self) -> SoftActorCriticOps: return SoftActorCriticOps( name=self._policy.name, policy=self._policy, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def _get_batch(self, batch_size: int = None) -> TransitionBatch: diff --git a/maro/rl/training/proxy.py b/maro/rl/training/proxy.py index 29eaaed7a..04f1af849 100644 --- a/maro/rl/training/proxy.py +++ b/maro/rl/training/proxy.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. from collections import defaultdict, deque +from typing import Deque -from maro.rl.distributed import AbsProxy +from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT, AbsProxy from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes from maro.rl.utils.torch_utils import average_grads from maro.utils import LoggerV2 @@ -20,13 +21,16 @@ class TrainingProxy(AbsProxy): backend_port (int, default=10001): Network port for communicating with back-end workers (task consumers). 
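Under the reworked constructors, a SAC trainer is assembled with the shared hyperparameters on the trainer itself and only algorithm-specific settings in ``SoftActorCriticParams``. A hedged wiring sketch (the factory argument and all hyperparameter values below are illustrative, not part of this patch):

from typing import Callable

from maro.rl.model import QNet
from maro.rl.training.algorithms.sac import SoftActorCriticParams, SoftActorCriticTrainer


def get_sac(q_net_factory: Callable[[], QNet], name: str) -> SoftActorCriticTrainer:
    # Shared knobs (discount, replay memory size, batch size) now sit on the trainer.
    return SoftActorCriticTrainer(
        name=name,
        reward_discount=0.99,
        replay_memory_capacity=100000,
        batch_size=64,
        params=SoftActorCriticParams(
            get_q_critic_net_func=q_net_factory,
            entropy_coef=0.2,
            soft_update_coef=0.05,
        ),
    )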
""" - def __init__(self, frontend_port: int = 10000, backend_port: int = 10001) -> None: - super(TrainingProxy, self).__init__(frontend_port=frontend_port, backend_port=backend_port) - self._available_workers = deque() - self._worker_ready = False - self._connected_ops = set() - self._result_cache = defaultdict(list) - self._expected_num_results = {} + def __init__(self, frontend_port: int = None, backend_port: int = None) -> None: + super(TrainingProxy, self).__init__( + frontend_port=frontend_port if frontend_port is not None else DEFAULT_TRAINING_FRONTEND_PORT, + backend_port=backend_port if backend_port is not None else DEFAULT_TRAINING_BACKEND_PORT, + ) + self._available_workers: Deque = deque() + self._worker_ready: bool = False + self._connected_ops: set = set() + self._result_cache: dict = defaultdict(list) + self._expected_num_results: dict = {} self._logger = LoggerV2("TRAIN-PROXY") def _route_request_to_compute_node(self, msg: list) -> None: @@ -48,10 +52,12 @@ def _route_request_to_compute_node(self, msg: list) -> None: self._connected_ops.add(msg[0]) req = bytes_to_pyobj(msg[-1]) + assert isinstance(req, dict) + desired_parallelism = req["desired_parallelism"] req["args"] = list(req["args"]) batch = req["args"][0] - workers = [] + workers: list = [] while len(workers) < desired_parallelism and self._available_workers: workers.append(self._available_workers.popleft()) diff --git a/maro/rl/training/train_ops.py b/maro/rl/training/train_ops.py index 934d965e8..57888038a 100644 --- a/maro/rl/training/train_ops.py +++ b/maro/rl/training/train_ops.py @@ -3,8 +3,9 @@ import inspect from abc import ABCMeta, abstractmethod -from typing import Callable, Tuple +from typing import Any, Callable, Optional, Tuple, Union +import torch import zmq from zmq.asyncio import Context, Poller @@ -32,10 +33,8 @@ def __init__( super(AbsTrainOps, self).__init__() self._name = name self._policy = policy - self._parallelism = parallelism - - self._device = None + self._device: Optional[torch.device] = None @property def name(self) -> str: @@ -43,11 +42,11 @@ def name(self) -> str: @property def policy_state_dim(self) -> int: - return self._policy.state_dim if self._policy else None + return self._policy.state_dim @property def policy_action_dim(self) -> int: - return self._policy.action_dim if self._policy else None + return self._policy.action_dim @property def parallelism(self) -> int: @@ -74,12 +73,12 @@ def set_state(self, ops_state_dict: dict) -> None: self.set_policy_state(ops_state_dict["policy"][1]) self.set_non_policy_state(ops_state_dict["non_policy"]) - def get_policy_state(self) -> Tuple[str, object]: + def get_policy_state(self) -> Tuple[str, dict]: """Get the policy's state. Returns: policy_name (str) - policy_state (object) + policy_state (Any) """ return self._policy.name, self._policy.get_state() @@ -110,17 +109,17 @@ def set_non_policy_state(self, state: dict) -> None: raise NotImplementedError @abstractmethod - def to_device(self, device: str): + def to_device(self, device: str = None) -> None: raise NotImplementedError -def remote(func) -> Callable: +def remote(func: Callable) -> Callable: """Annotation to indicate that a function / method can be called remotely. This annotation takes effect only when an ``AbsTrainOps`` object is wrapped by a ``RemoteOps``. 
""" - def remote_annotate(*args, **kwargs) -> object: + def remote_annotate(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) return remote_annotate @@ -136,7 +135,7 @@ class AsyncClient(object): """ def __init__(self, name: str, address: Tuple[str, int], logger: LoggerV2 = None) -> None: - self._logger = DummyLogger() if logger is None else logger + self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger() self._name = name host, port = address self._proxy_ip = get_ip_address_by_hostname(host) @@ -154,7 +153,7 @@ async def send_request(self, req: dict) -> None: await self._socket.send(pyobj_to_bytes(req)) self._logger.debug(f"{self._name} sent request {req['func']}") - async def get_response(self) -> object: + async def get_response(self) -> Any: """Waits for a result in asynchronous fashion. This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``get_response`` calls. @@ -208,15 +207,15 @@ def __init__(self, ops: AbsTrainOps, address: Tuple[str, int], logger: LoggerV2 self._client = AsyncClient(self._ops.name, address, logger=logger) self._client.connect() - def __getattribute__(self, attr_name: str) -> object: + def __getattribute__(self, attr_name: str) -> Any: # Ignore methods that belong to the parent class try: return super().__getattribute__(attr_name) except AttributeError: pass - def remote_method(ops_state, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable: - async def remote_call(*args, **kwargs) -> object: + def remote_method(ops_state: Any, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable: + async def remote_call(*args: Any, **kwargs: Any) -> Any: req = { "state": ops_state, "func": func_name, diff --git a/maro/rl/training/trainer.py b/maro/rl/training/trainer.py index a1ea52096..8bced5674 100644 --- a/maro/rl/training/trainer.py +++ b/maro/rl/training/trainer.py @@ -21,37 +21,8 @@ @dataclass -class TrainerParams: - """Common trainer parameters. - - replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory. - batch_size (int, default=128): Training batch size. - data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when - a model is large and computing gradients with respect to a batch becomes expensive. In this case, the - batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set - of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets - updated only after collecting all the gradients from the remote nodes. Note that this value is the desired - parallelism and the actual parallelism in a distributed experiment may be smaller depending on the - availability of compute resources. For details on distributed deep learning and data parallelism, see - https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance - of resources available on the internet. - reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology. - - """ - - replay_memory_capacity: int = 10000 - batch_size: int = 128 - data_parallelism: int = 1 - reward_discount: float = 0.9 - - @abstractmethod - def extract_ops_params(self) -> Dict[str, object]: - """Extract parameters that should be passed to the train ops. - - Returns: - params (Dict[str, object]): Parameter dict. 
- """ - raise NotImplementedError +class BaseTrainerParams: + pass class AbsTrainer(object, metaclass=ABCMeta): @@ -64,16 +35,36 @@ class AbsTrainer(object, metaclass=ABCMeta): Args: name (str): Name of the trainer. - params (TrainerParams): Trainer's parameters. + replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory. + batch_size (int, default=128): Training batch size. + data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when + a model is large and computing gradients with respect to a batch becomes expensive. In this case, the + batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set + of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets + updated only after collecting all the gradients from the remote nodes. Note that this value is the desired + parallelism and the actual parallelism in a distributed experiment may be smaller depending on the + availability of compute resources. For details on distributed deep learning and data parallelism, see + https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an + abundance of resources available on the internet. + reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology. """ - def __init__(self, name: str, params: TrainerParams) -> None: + def __init__( + self, + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: self._name = name - self._params = params - self._batch_size = self._params.batch_size + self._replay_memory_capacity = replay_memory_capacity + self._batch_size = batch_size + self._data_parallelism = data_parallelism + self._reward_discount = reward_discount + self._agent2policy: Dict[Any, str] = {} self._proxy_address: Optional[Tuple[str, int]] = None - self._logger = None @property def name(self) -> str: @@ -83,7 +74,7 @@ def name(self) -> str: def agent_num(self) -> int: return len(self._agent2policy) - def register_logger(self, logger: LoggerV2) -> None: + def register_logger(self, logger: LoggerV2 = None) -> None: self._logger = logger def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_mapping: Dict[str, str]) -> None: @@ -140,7 +131,7 @@ def set_proxy_address(self, proxy_address: Tuple[str, int]) -> None: self._proxy_address = proxy_address @abstractmethod - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: """Get policies' states. 
Returns: @@ -164,22 +155,46 @@ async def exit(self) -> None: class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta): """Policy trainer that trains only one policy.""" - def __init__(self, name: str, params: TrainerParams) -> None: - super(SingleAgentTrainer, self).__init__(name, params) - self._policy: Optional[RLPolicy] = None - self._ops: Optional[AbsTrainOps] = None - self._replay_memory: Optional[ReplayMemory] = None + def __init__( + self, + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(SingleAgentTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) + + @property + def ops(self) -> Union[AbsTrainOps, RemoteOps]: + ops = getattr(self, "_ops", None) + assert isinstance(ops, (AbsTrainOps, RemoteOps)) + return ops @property - def ops(self): - return self._ops + def replay_memory(self) -> ReplayMemory: + replay_memory = getattr(self, "_replay_memory", None) + assert isinstance(replay_memory, ReplayMemory), "Replay memory is required." + return replay_memory def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None: policies = [policy for policy in policies if policy_trainer_mapping[policy.name] == self.name] if len(policies) != 1: raise ValueError(f"Trainer {self._name} should have exactly one policy assigned to it") - self._policy = policies.pop() + policy = policies.pop() + assert isinstance(policy, RLPolicy) + self._register_policy(policy) + + @abstractmethod + def _register_policy(self, policy: RLPolicy) -> None: + raise NotImplementedError @abstractmethod def get_local_ops(self) -> AbsTrainOps: @@ -201,9 +216,9 @@ def get_ops(self) -> Union[RemoteOps, AbsTrainOps]: ops = self.get_local_ops() return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: self._assert_ops_exists() - policy_name, state = self._ops.get_policy_state() + policy_name, state = self.ops.get_policy_state() return {policy_name: state} def load(self, path: str) -> None: @@ -212,7 +227,7 @@ def load(self, path: str) -> None: policy_state = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) non_policy_state = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) - self._ops.set_state( + self.ops.set_state( { "policy": policy_state, "non_policy": non_policy_state, @@ -222,7 +237,7 @@ def load(self, path: str) -> None: def save(self, path: str) -> None: self._assert_ops_exists() - ops_state = self._ops.get_state() + ops_state = self.ops.get_state() policy_state = ops_state["policy"] non_policy_state = ops_state["non_policy"] @@ -252,40 +267,57 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: next_states=np.vstack([exp[4] for exp in exps]), ) transition_batch = self._preprocess_batch(transition_batch) - self._replay_memory.put(transition_batch) + self.replay_memory.put(transition_batch) @abstractmethod def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: raise NotImplementedError def _assert_ops_exists(self) -> None: - if not self._ops: + if not self.ops: raise ValueError("'build' needs to be called to create an ops instance first.") async def exit(self) -> None: self._assert_ops_exists() - if isinstance(self._ops, RemoteOps): - await self._ops.exit() + ops = self.ops + if 
isinstance(ops, RemoteOps): + await ops.exit() class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta): """Policy trainer that trains multiple policies.""" - def __init__(self, name: str, params: TrainerParams) -> None: - super(MultiAgentTrainer, self).__init__(name, params) - self._policy_names: List[str] = [] - self._policy_dict: Dict[str, RLPolicy] = {} - self._ops_dict: Dict[str, AbsTrainOps] = {} + def __init__( + self, + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(MultiAgentTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) @property - def ops_dict(self): - return self._ops_dict + def ops_dict(self) -> Dict[str, AbsTrainOps]: + ops_dict = getattr(self, "_ops_dict", None) + assert isinstance(ops_dict, dict) + return ops_dict def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None: - self._policy_names = [policy.name for policy in policies if policy_trainer_mapping[policy.name] == self.name] - self._policy_dict = { - policy.name: policy for policy in policies if policy_trainer_mapping[policy.name] == self.name - } + self._policy_names: List[str] = [ + policy.name for policy in policies if policy_trainer_mapping[policy.name] == self.name + ] + self._policy_dict: Dict[str, RLPolicy] = {} + for policy in policies: + if policy_trainer_mapping[policy.name] == self.name: + assert isinstance(policy, RLPolicy) + self._policy_dict[policy.name] = policy @abstractmethod def get_local_ops(self, name: str) -> AbsTrainOps: @@ -314,7 +346,7 @@ def get_ops(self, name: str) -> Union[RemoteOps, AbsTrainOps]: return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops @abstractmethod - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: raise NotImplementedError @abstractmethod diff --git a/maro/rl/training/training_manager.py b/maro/rl/training/training_manager.py index 56e67d63b..9d6b36b15 100644 --- a/maro/rl/training/training_manager.py +++ b/maro/rl/training/training_manager.py @@ -7,7 +7,6 @@ import collections import os import typing -from itertools import chain from typing import Any, Dict, Iterable, List, Tuple from maro.rl.rollout import ExpElement @@ -92,13 +91,16 @@ async def train_step() -> Iterable: for trainer in self._trainer_dict.values(): trainer.train_step() - def get_policy_state(self) -> Dict[str, Dict[str, object]]: + def get_policy_state(self) -> Dict[str, dict]: """Get policies' states. Returns: A double-deck dict with format: {trainer_name: {policy_name: policy_state}} """ - return dict(chain(*[trainer.get_policy_state().items() for trainer in self._trainer_dict.values()])) + policy_states: Dict[str, dict] = {} + for trainer in self._trainer_dict.values(): + policy_states.update(trainer.get_policy_state()) + return policy_states def record_experiences(self, experiences: List[List[ExpElement]]) -> None: """Record experiences collected from external modules (for example, EnvSampler). 
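The ``ops``, ``replay_memory``, and ``ops_dict`` accessors introduced above all follow the same pattern: the attribute is created later in ``build()``, and the property both narrows its type and fails fast if it is read too early. A generic, standalone sketch of that pattern (the class and names are illustrative):

from typing import Any


class LazyHolder:
    """Toy illustration of the build-then-access pattern used by the trainers."""

    def build(self) -> None:
        # The attribute is created only here, so __init__ needs no Optional placeholder.
        self._resource: Any = object()

    @property
    def resource(self) -> Any:
        resource = getattr(self, "_resource", None)
        assert resource is not None, "'build' must be called before accessing 'resource'"
        return resource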
diff --git a/maro/rl/training/worker.py b/maro/rl/training/worker.py index 4e30e816f..4cb1528f4 100644 --- a/maro/rl/training/worker.py +++ b/maro/rl/training/worker.py @@ -6,7 +6,7 @@ import typing from typing import Dict -from maro.rl.distributed import AbsWorker +from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, AbsWorker from maro.rl.training import SingleAgentTrainer from maro.rl.utils.common import bytes_to_pyobj, bytes_to_string, pyobj_to_bytes from maro.utils import LoggerV2 @@ -34,13 +34,13 @@ def __init__( idx: int, rl_component_bundle: RLComponentBundle, producer_host: str, - producer_port: int = 10001, + producer_port: int = None, logger: LoggerV2 = None, ) -> None: super(TrainOpsWorker, self).__init__( idx=idx, producer_host=producer_host, - producer_port=producer_port, + producer_port=producer_port if producer_port is not None else DEFAULT_TRAINING_BACKEND_PORT, logger=logger, ) diff --git a/maro/rl/utils/common.py b/maro/rl/utils/common.py index e69b907b7..516239670 100644 --- a/maro/rl/utils/common.py +++ b/maro/rl/utils/common.py @@ -4,17 +4,17 @@ import os import pickle import socket -from typing import List, Optional +from typing import Any, List, Optional -def get_env(var_name: str, required: bool = True, default: object = None) -> str: +def get_env(var_name: str, required: bool = True, default: str = None) -> Optional[str]: """Wrapper for os.getenv() that includes a check for mandatory environment variables. Args: var_name (str): Variable name. required (bool, default=True): Flag indicating whether the environment variable in questions is required. If this is true and the environment variable is not present in ``os.environ``, a ``KeyError`` is raised. - default (object, default=None): Default value for the environment variable if it is missing in ``os.environ`` + default (str, default=None): Default value for the environment variable if it is missing in ``os.environ`` and ``required`` is false. Ignored if ``required`` is True. 
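Because ``get_env`` is now annotated as returning ``Optional[str]``, callers either guard against ``None`` explicitly or go through helpers such as ``int_or_none``. A hedged usage sketch:

from maro.rl.utils.common import get_env

# Optional variable: may legitimately be absent, so guard before converting.
num_steps_raw = get_env("NUM_STEPS", required=False)
num_steps = int(num_steps_raw) if num_steps_raw is not None else None

# Even with a default supplied, the annotated return type stays Optional[str],
# which is what the narrowing helpers are for.
file_level = get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL")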
Returns: @@ -52,11 +52,11 @@ def bytes_to_string(bytes_: bytes) -> str: return bytes_.decode(DEFAULT_MSG_ENCODING) -def pyobj_to_bytes(pyobj) -> bytes: +def pyobj_to_bytes(pyobj: Any) -> bytes: return pickle.dumps(pyobj) -def bytes_to_pyobj(bytes_: bytes) -> object: +def bytes_to_pyobj(bytes_: bytes) -> Any: return pickle.loads(bytes_) diff --git a/maro/rl/utils/torch_utils.py b/maro/rl/utils/torch_utils.py index 3335fe24a..82476411f 100644 --- a/maro/rl/utils/torch_utils.py +++ b/maro/rl/utils/torch_utils.py @@ -55,5 +55,5 @@ def average_grads(grad_list: List[dict]) -> dict: } -def get_torch_device(device: str = None): +def get_torch_device(device: str = None) -> torch.device: return torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu")) diff --git a/maro/rl/workflows/config/parser.py b/maro/rl/workflows/config/parser.py index 68e8cbb30..db52f065a 100644 --- a/maro/rl/workflows/config/parser.py +++ b/maro/rl/workflows/config/parser.py @@ -207,7 +207,7 @@ def _validate_checkpointing_section(self, section: dict) -> None: f"{self._validation_err_pfx}: 'training.checkpointing.interval' must be an int", ) - def _validate_logging_section(self, component, level_dict: dict) -> None: + def _validate_logging_section(self, component: str, level_dict: dict) -> None: if any(key not in {"stdout", "file"} for key in level_dict): raise KeyError( f"{self._validation_err_pfx}: fields under section '{component}.logging' must be 'stdout' or 'file'", @@ -261,7 +261,7 @@ def get_job_spec(self, containerize: bool = False) -> Dict[str, Tuple[str, Dict[ num_episodes = self._config["main"]["num_episodes"] main_proc = f"{self._config['job']}.main" min_n_sample = self._config["main"].get("min_n_sample", 1) - env = { + env: dict = { main_proc: ( os.path.join(self._get_workflow_path(containerize=containerize), "main.py"), { diff --git a/maro/rl/workflows/main.py b/maro/rl/workflows/main.py index d558bee67..31de7caa1 100644 --- a/maro/rl/workflows/main.py +++ b/maro/rl/workflows/main.py @@ -14,20 +14,21 @@ from maro.rl.utils import get_torch_device from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none from maro.rl.utils.training import get_latest_ep +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 class WorkflowEnvAttributes: def __init__(self) -> None: # Number of training episodes - self.num_episodes = int(get_env("NUM_EPISODES")) + self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES"))) # Maximum number of steps in on round of sampling. self.num_steps = int_or_none(get_env("NUM_STEPS", required=False)) # Minimum number of data samples to start a round of training. If the data samples are insufficient, re-run # data sampling until we have at least `min_n_sample` data entries. - self.min_n_sample = int_or_none(get_env("MIN_N_SAMPLE")) + self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE"))) # Path to store logs. self.log_path = get_env("LOG_PATH") @@ -57,7 +58,7 @@ def __init__(self) -> None: # Parallel sampling configurations. 
self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None if self.parallel_rollout: - self.port = int(get_env("ROLLOUT_CONTROLLER_PORT")) + self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT"))) self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False)) self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False)) @@ -65,7 +66,10 @@ def __init__(self) -> None: # Distributed training configurations. if self.train_mode != "simple": - self.proxy_address = (get_env("TRAIN_PROXY_HOST"), int(get_env("TRAIN_PROXY_FRONTEND_PORT"))) + self.proxy_address = ( + env_str_helper(get_env("TRAIN_PROXY_HOST")), + int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))), + ) self.logger = LoggerV2( "MAIN", @@ -87,7 +91,8 @@ def _get_env_sampler( env_attr: WorkflowEnvAttributes, ) -> Union[AbsEnvSampler, BatchEnvSampler]: if env_attr.parallel_rollout: - env_sampler = BatchEnvSampler( + assert env_attr.env_sampling_parallelism is not None + return BatchEnvSampler( sampling_parallelism=env_attr.env_sampling_parallelism, port=env_attr.port, min_env_samples=env_attr.min_env_samples, @@ -100,8 +105,7 @@ def _get_env_sampler( if rl_component_bundle.device_mapping is not None: for policy_name, device_name in rl_component_bundle.device_mapping.items(): env_sampler.assign_policy_to_device(policy_name, get_torch_device(device_name)) - - return env_sampler + return env_sampler def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes, args: argparse.Namespace) -> None: @@ -144,7 +148,7 @@ def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: Workflow # main loop for ep in range(start_ep, env_attr.num_episodes + 1): - collect_time = training_time = 0 + collect_time = training_time = 0.0 total_experiences: List[List[ExpElement]] = [] total_info_list: List[dict] = [] n_sample = 0 @@ -214,7 +218,7 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) diff --git a/maro/rl/workflows/rollout_worker.py b/maro/rl/workflows/rollout_worker.py index 47a5bbec6..8343873b3 100644 --- a/maro/rl/workflows/rollout_worker.py +++ b/maro/rl/workflows/rollout_worker.py @@ -8,17 +8,18 @@ from maro.rl.rl_component.rl_component_bundle import RLComponentBundle from maro.rl.rollout import RolloutWorker from maro.rl.utils.common import get_env, int_or_none +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle") - worker_idx = int_or_none(get_env("ID")) + worker_idx = int(env_str_helper(get_env("ID"))) logger = LoggerV2( f"ROLLOUT-WORKER.{worker_idx}", dump_path=get_env("LOG_PATH"), @@ -29,7 +30,7 @@ worker = RolloutWorker( idx=worker_idx, rl_component_bundle=rl_component_bundle, - producer_host=get_env("ROLLOUT_CONTROLLER_HOST"), + producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")), 
producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")), logger=logger, ) diff --git a/maro/rl/workflows/train_worker.py b/maro/rl/workflows/train_worker.py index 011cc0c11..4565c5b72 100644 --- a/maro/rl/workflows/train_worker.py +++ b/maro/rl/workflows/train_worker.py @@ -8,10 +8,11 @@ from maro.rl.rl_component.rl_component_bundle import RLComponentBundle from maro.rl.training import TrainOpsWorker from maro.rl.utils.common import get_env, int_or_none +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) @@ -27,9 +28,9 @@ file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), ) worker = TrainOpsWorker( - idx=int_or_none(get_env("ID")), + idx=int(env_str_helper(get_env("ID"))), rl_component_bundle=rl_component_bundle, - producer_host=get_env("TRAIN_PROXY_HOST"), + producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")), producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")), logger=logger, ) diff --git a/maro/rl/workflows/utils.py b/maro/rl/workflows/utils.py new file mode 100644 index 000000000..accfbe86f --- /dev/null +++ b/maro/rl/workflows/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Optional + + +def env_str_helper(string: Optional[str]) -> str: + assert string is not None + return string diff --git a/maro/simulator/abs_core.py b/maro/simulator/abs_core.py index b47d94baa..f50ca53cb 100644 --- a/maro/simulator/abs_core.py +++ b/maro/simulator/abs_core.py @@ -72,7 +72,7 @@ def business_engine(self) -> AbsBusinessEngine: return self._business_engine @abstractmethod - def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]: + def step(self, action) -> Tuple[Optional[dict], Optional[list], bool]: """Push the environment to next step with action. Args: diff --git a/maro/simulator/core.py b/maro/simulator/core.py index 5a9524e98..456dbc98e 100644 --- a/maro/simulator/core.py +++ b/maro/simulator/core.py @@ -89,7 +89,7 @@ def __init__( self._streamit_episode = 0 - def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]: + def step(self, action) -> Tuple[Optional[dict], Optional[list], bool]: """Push the environment to next step with action. Args: @@ -267,7 +267,7 @@ def _init_business_engine(self) -> None: additional_options=self._additional_options, ) - def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]: + def _simulate(self) -> Generator[Tuple[dict, list, bool], object, None]: """This is the generator to wrap each episode process.""" self._streamit_episode += 1
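A short usage sketch of the ``env_str_helper`` narrowing helper added in ``maro/rl/workflows/utils.py``; the variable names are illustrative, mirroring how the workflow scripts consume mandatory environment variables:

from maro.rl.utils.common import get_env
from maro.rl.workflows.utils import env_str_helper

# Mandatory variables: assert away the Optional so int() and path handling type-check.
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
worker_idx = int(env_str_helper(get_env("ID")))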