From 135d2fc894bf35c1dc28a0c11ce8c8bbb03c8cf2 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 27 Dec 2022 16:28:31 +0800 Subject: [PATCH] Refine rl component bundle (#549) * Config files * Done * Minor bugfix * Add autoflake * Update isort exclude; add pre-commit to requirements * Check only isort * Minor * Format * Test passed * Run pre-commit * Minor bugfix in rl_component_bundle * Pass mypy * Fix a bug in RL notebook * A minor bug fix * Add upper bound for numpy version in test --- examples/cim/rl/__init__.py | 4 +- examples/cim/rl/algorithms/ac.py | 2 +- examples/cim/rl/algorithms/dqn.py | 6 +- examples/cim/rl/algorithms/maddpg.py | 2 +- examples/cim/rl/algorithms/ppo.py | 3 +- examples/cim/rl/config.py | 5 - examples/cim/rl/rl_component_bundle.py | 103 ++-- examples/rl/cim_distributed.yml | 47 ++ examples/vm_scheduling/rl/__init__.py | 4 +- examples/vm_scheduling/rl/algorithms/ac.py | 2 +- examples/vm_scheduling/rl/algorithms/dqn.py | 8 +- examples/vm_scheduling/rl/env_sampler.py | 26 +- .../vm_scheduling/rl/rl_component_bundle.py | 90 ++-- maro/rl/distributed/__init__.py | 4 + maro/rl/distributed/abs_worker.py | 3 +- maro/rl/distributed/port_config.py | 6 + maro/rl/exploration/scheduling.py | 5 +- maro/rl/model/abs_net.py | 20 +- maro/rl/model/fc_block.py | 8 +- maro/rl/policy/abs_policy.py | 16 +- maro/rl/policy/continuous_rl_policy.py | 6 +- maro/rl/policy/discrete_rl_policy.py | 10 +- maro/rl/rl_component/rl_component_bundle.py | 259 ++++------ maro/rl/rollout/batch_env_sampler.py | 42 +- maro/rl/rollout/env_sampler.py | 130 ++--- maro/rl/rollout/worker.py | 33 +- maro/rl/training/__init__.py | 4 +- maro/rl/training/algorithms/ac.py | 32 +- .../training/algorithms/base/ac_ppo_base.py | 82 ++-- maro/rl/training/algorithms/ddpg.py | 78 +-- maro/rl/training/algorithms/dqn.py | 69 +-- maro/rl/training/algorithms/maddpg.py | 143 +++--- maro/rl/training/algorithms/ppo.py | 75 ++- maro/rl/training/algorithms/sac.py | 91 ++-- maro/rl/training/proxy.py | 24 +- maro/rl/training/train_ops.py | 42 +- maro/rl/training/trainer.py | 205 ++++---- maro/rl/training/training_manager.py | 37 +- maro/rl/training/worker.py | 22 +- maro/rl/utils/common.py | 10 +- maro/rl/utils/torch_utils.py | 2 +- maro/rl/workflows/config/parser.py | 4 +- maro/rl/workflows/main.py | 242 +++++----- maro/rl/workflows/rollout_worker.py | 11 +- maro/rl/workflows/train_worker.py | 11 +- maro/rl/workflows/utils.py | 9 + maro/simulator/abs_core.py | 2 +- .../rl_formulation.ipynb | 452 ++++++++++++++++++ tests/requirements.test.txt | 2 +- tests/test_env.py | 2 +- 50 files changed, 1495 insertions(+), 1000 deletions(-) create mode 100644 examples/rl/cim_distributed.yml create mode 100644 maro/rl/distributed/port_config.py create mode 100644 maro/rl/workflows/utils.py create mode 100644 notebooks/container_inventory_management/rl_formulation.ipynb diff --git a/examples/cim/rl/__init__.py b/examples/cim/rl/__init__.py index 695d90ede..90be439f0 100644 --- a/examples/cim/rl/__init__.py +++ b/examples/cim/rl/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from .rl_component_bundle import CIMBundle as rl_component_bundle_cls +from .rl_component_bundle import rl_component_bundle __all__ = [ - "rl_component_bundle_cls", + "rl_component_bundle", ] diff --git a/examples/cim/rl/algorithms/ac.py b/examples/cim/rl/algorithms/ac.py index 69bbc49c5..1769493df 100644 --- a/examples/cim/rl/algorithms/ac.py +++ b/examples/cim/rl/algorithms/ac.py @@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG def get_ac(state_dim: int, name: str) -> ActorCriticTrainer: return ActorCriticTrainer( name=name, + reward_discount=0.0, params=ActorCriticParams( get_v_critic_net_func=lambda: MyCriticNet(state_dim), - reward_discount=0.0, grad_iters=10, critic_loss_cls=torch.nn.SmoothL1Loss, min_logp=None, diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py index c5999424a..d62e3443d 100644 --- a/examples/cim/rl/algorithms/dqn.py +++ b/examples/cim/rl/algorithms/dqn.py @@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli def get_dqn(name: str) -> DQNTrainer: return DQNTrainer( name=name, + reward_discount=0.0, + replay_memory_capacity=10000, + batch_size=32, params=DQNParams( - reward_discount=0.0, update_target_every=5, num_epochs=10, soft_update_coef=0.1, double=False, - replay_memory_capacity=10000, random_overwrite=False, - batch_size=32, ), ) diff --git a/examples/cim/rl/algorithms/maddpg.py b/examples/cim/rl/algorithms/maddpg.py index 7d964f6bb..e6fd0a65b 100644 --- a/examples/cim/rl/algorithms/maddpg.py +++ b/examples/cim/rl/algorithms/maddpg.py @@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer: return DiscreteMADDPGTrainer( name=name, + reward_discount=0.0, params=DiscreteMADDPGParams( - reward_discount=0.0, num_epoch=10, get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims), shared_critic=False, diff --git a/examples/cim/rl/algorithms/ppo.py b/examples/cim/rl/algorithms/ppo.py index d2e2df0d9..f18a08750 100644 --- a/examples/cim/rl/algorithms/ppo.py +++ b/examples/cim/rl/algorithms/ppo.py @@ -16,12 +16,11 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy def get_ppo(state_dim: int, name: str) -> PPOTrainer: return PPOTrainer( name=name, + reward_discount=0.0, params=PPOParams( get_v_critic_net_func=lambda: MyCriticNet(state_dim), - reward_discount=0.0, grad_iters=10, critic_loss_cls=torch.nn.SmoothL1Loss, - min_logp=None, lam=0.0, clip_ratio=0.1, ), diff --git a/examples/cim/rl/config.py b/examples/cim/rl/config.py index 78d943577..a46194900 100644 --- a/examples/cim/rl/config.py +++ b/examples/cim/rl/config.py @@ -7,11 +7,6 @@ "durations": 560, } -if env_conf["topology"].startswith("toy"): - num_agents = int(env_conf["topology"].split(".")[1][0]) -else: - num_agents = int(env_conf["topology"].split(".")[1][:2]) - port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"] vessel_attributes = ["empty", "full", "remaining_space"] diff --git a/examples/cim/rl/rl_component_bundle.py b/examples/cim/rl/rl_component_bundle.py index 5b16aed97..d290c8f1d 100644 --- a/examples/cim/rl/rl_component_bundle.py +++ b/examples/cim/rl/rl_component_bundle.py @@ -1,77 +1,48 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from functools import partial -from typing import Any, Callable, Dict, Optional - -from maro.rl.policy import AbsPolicy from maro.rl.rl_component.rl_component_bundle import RLComponentBundle -from maro.rl.rollout import AbsEnvSampler -from maro.rl.training import AbsTrainer +from maro.simulator import Env from .algorithms.ac import get_ac, get_ac_policy from .algorithms.dqn import get_dqn, get_dqn_policy from .algorithms.maddpg import get_maddpg, get_maddpg_policy from .algorithms.ppo import get_ppo, get_ppo_policy -from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim +from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim from examples.cim.rl.env_sampler import CIMEnvSampler - -class CIMBundle(RLComponentBundle): - def get_env_config(self) -> dict: - return env_conf - - def get_test_env_config(self) -> Optional[dict]: - return None - - def get_env_sampler(self) -> AbsEnvSampler: - return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"]) - - def get_agent2policy(self) -> Dict[Any, str]: - return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list} - - def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: - if algorithm == "ac": - policy_creator = { - f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy") - for i in range(num_agents) - } - elif algorithm == "ppo": - policy_creator = { - f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy") - for i in range(num_agents) - } - elif algorithm == "dqn": - policy_creator = { - f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy") - for i in range(num_agents) - } - elif algorithm == "discrete_maddpg": - policy_creator = { - f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy") - for i in range(num_agents) - } - else: - raise ValueError(f"Unsupported algorithm: {algorithm}") - - return policy_creator - - def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: - if algorithm == "ac": - trainer_creator = { - f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents) - } - elif algorithm == "ppo": - trainer_creator = { - f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents) - } - elif algorithm == "dqn": - trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)} - elif algorithm == "discrete_maddpg": - trainer_creator = { - f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents) - } - else: - raise ValueError(f"Unsupported algorithm: {algorithm}") - - return trainer_creator +# Environments +learn_env = Env(**env_conf) +test_env = learn_env + +# Agent, policy, and trainers +num_agents = len(learn_env.agent_idx_list) +agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} +if algorithm == "ac": + policies = [get_ac_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] + trainers = [get_ac(state_dim, f"{algorithm}_{i}") for i in range(num_agents)] +elif algorithm == "ppo": + policies = [get_ppo_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] + trainers = [get_ppo(state_dim, f"{algorithm}_{i}") for i in 
range(num_agents)] +elif algorithm == "dqn": + policies = [get_dqn_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] + trainers = [get_dqn(f"{algorithm}_{i}") for i in range(num_agents)] +elif algorithm == "discrete_maddpg": + policies = [get_maddpg_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] + trainers = [get_maddpg(state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)] +else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + +# Build RLComponentBundle +rl_component_bundle = RLComponentBundle( + env_sampler=CIMEnvSampler( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + reward_eval_delay=reward_shaping_conf["time_window"], + ), + agent2policy=agent2policy, + policies=policies, + trainers=trainers, +) diff --git a/examples/rl/cim_distributed.yml b/examples/rl/cim_distributed.yml new file mode 100644 index 000000000..3b11cb6e1 --- /dev/null +++ b/examples/rl/cim_distributed.yml @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for CIM scenario. +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +# Run this workflow by executing one of the following commands: +# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml +# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml + +job: cim_rl_workflow +scenario_path: "examples/cim/rl" +log_path: "log/rl_job/cim.txt" +main: + num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training. + num_steps: null + eval_schedule: 5 + logging: + stdout: INFO + file: DEBUG +rollout: + parallelism: + sampling: 3 + eval: null + min_env_samples: 3 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 20000 + logging: + stdout: INFO + file: DEBUG +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: "checkpoint/rl_job/cim" + interval: 5 + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 + logging: + stdout: INFO + file: DEBUG diff --git a/examples/vm_scheduling/rl/__init__.py b/examples/vm_scheduling/rl/__init__.py index 44e5138a2..90be439f0 100644 --- a/examples/vm_scheduling/rl/__init__.py +++ b/examples/vm_scheduling/rl/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from .rl_component_bundle import VMBundle as rl_component_bundle_cls +from .rl_component_bundle import rl_component_bundle __all__ = [ - "rl_component_bundle_cls", + "rl_component_bundle", ] diff --git a/examples/vm_scheduling/rl/algorithms/ac.py b/examples/vm_scheduling/rl/algorithms/ac.py index 411d35d6b..94d0afd63 100644 --- a/examples/vm_scheduling/rl/algorithms/ac.py +++ b/examples/vm_scheduling/rl/algorithms/ac.py @@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str) def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer: return ActorCriticTrainer( name=name, + reward_discount=0.9, params=ActorCriticParams( get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features), - reward_discount=0.9, grad_iters=100, critic_loss_cls=torch.nn.MSELoss, min_logp=-20, diff --git a/examples/vm_scheduling/rl/algorithms/dqn.py b/examples/vm_scheduling/rl/algorithms/dqn.py index a94989418..499cb85b5 100644 --- a/examples/vm_scheduling/rl/algorithms/dqn.py +++ b/examples/vm_scheduling/rl/algorithms/dqn.py @@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str def get_dqn(name: str) -> DQNTrainer: return DQNTrainer( name=name, + reward_discount=0.9, + replay_memory_capacity=10000, + batch_size=32, + data_parallelism=2, params=DQNParams( - reward_discount=0.9, update_target_every=5, num_epochs=100, soft_update_coef=0.1, double=False, - replay_memory_capacity=10000, random_overwrite=False, - batch_size=32, - data_parallelism=2, ), ) diff --git a/examples/vm_scheduling/rl/env_sampler.py b/examples/vm_scheduling/rl/env_sampler.py index 19d6171a9..3fc39776e 100644 --- a/examples/vm_scheduling/rl/env_sampler.py +++ b/examples/vm_scheduling/rl/env_sampler.py @@ -5,12 +5,13 @@ from collections import defaultdict from os import makedirs from os.path import dirname, join, realpath -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Type, Union import numpy as np from matplotlib import pyplot as plt -from maro.rl.rollout import AbsEnvSampler, CacheElement +from maro.rl.policy import AbsPolicy +from maro.rl.rollout import AbsAgentWrapper, AbsEnvSampler, CacheElement, SimpleAgentWrapper from maro.simulator import Env from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction @@ -30,8 +31,25 @@ class VMEnvSampler(AbsEnvSampler): - def __init__(self, learn_env: Env, test_env: Env) -> None: - super(VMEnvSampler, self).__init__(learn_env, test_env) + def __init__( + self, + learn_env: Env, + test_env: Env, + policies: List[AbsPolicy], + agent2policy: Dict[Any, str], + trainable_policies: List[str] = None, + agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper, + reward_eval_delay: int = None, + ) -> None: + super(VMEnvSampler, self).__init__( + learn_env, + test_env, + policies, + agent2policy, + trainable_policies, + agent_wrapper_cls, + reward_eval_delay, + ) self._learn_env.set_seed(seed) self._test_env.set_seed(test_seed) diff --git a/examples/vm_scheduling/rl/rl_component_bundle.py b/examples/vm_scheduling/rl/rl_component_bundle.py index 516a253f5..41b32a5bb 100644 --- a/examples/vm_scheduling/rl/rl_component_bundle.py +++ b/examples/vm_scheduling/rl/rl_component_bundle.py @@ -1,67 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from functools import partial -from typing import Any, Callable, Dict, Optional - -from maro.rl.policy import AbsPolicy from maro.rl.rl_component.rl_component_bundle import RLComponentBundle -from maro.rl.rollout import AbsEnvSampler -from maro.rl.training import AbsTrainer +from maro.simulator import Env -from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy -from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy +from .algorithms.ac import get_ac, get_ac_policy +from .algorithms.dqn import get_dqn, get_dqn_policy from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf from examples.vm_scheduling.rl.env_sampler import VMEnvSampler - -class VMBundle(RLComponentBundle): - def get_env_config(self) -> dict: - return env_conf - - def get_test_env_config(self) -> Optional[dict]: - return test_env_conf - - def get_env_sampler(self) -> AbsEnvSampler: - return VMEnvSampler(self.env, self.test_env) - - def get_agent2policy(self) -> Dict[Any, str]: - return {"AGENT": f"{algorithm}.policy"} - - def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: - action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1 - - if algorithm == "ac": - policy_creator = { - f"{algorithm}.policy": partial( - get_ac_policy, - state_dim, - action_num, - num_features, - f"{algorithm}.policy", - ), - } - elif algorithm == "dqn": - policy_creator = { - f"{algorithm}.policy": partial( - get_dqn_policy, - state_dim, - action_num, - num_features, - f"{algorithm}.policy", - ), - } - else: - raise ValueError(f"Unsupported algorithm: {algorithm}") - - return policy_creator - - def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: - if algorithm == "ac": - trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)} - elif algorithm == "dqn": - trainer_creator = {algorithm: partial(get_dqn, algorithm)} - else: - raise ValueError(f"Unsupported algorithm: {algorithm}") - - return trainer_creator +# Environments +learn_env = Env(**env_conf) +test_env = Env(**test_env_conf) + +# Agent, policy, and trainers +action_num = num_pms + 1 +agent2policy = {"AGENT": f"{algorithm}.policy"} +if algorithm == "ac": + policies = [get_ac_policy(state_dim, action_num, num_features, f"{algorithm}.policy")] + trainers = [get_ac(state_dim, num_features, algorithm)] +elif algorithm == "dqn": + policies = [get_dqn_policy(state_dim, action_num, num_features, f"{algorithm}.policy")] + trainers = [get_dqn(algorithm)] +else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + +# Build RLComponentBundle +rl_component_bundle = RLComponentBundle( + env_sampler=VMEnvSampler( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + ), + agent2policy=agent2policy, + policies=policies, + trainers=trainers, +) diff --git a/maro/rl/distributed/__init__.py b/maro/rl/distributed/__init__.py index b18d1ee59..828505c04 100644 --- a/maro/rl/distributed/__init__.py +++ b/maro/rl/distributed/__init__.py @@ -3,8 +3,12 @@ from .abs_proxy import AbsProxy from .abs_worker import AbsWorker +from .port_config import DEFAULT_ROLLOUT_PRODUCER_PORT, DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT __all__ = [ "AbsProxy", "AbsWorker", + "DEFAULT_ROLLOUT_PRODUCER_PORT", + "DEFAULT_TRAINING_FRONTEND_PORT", + "DEFAULT_TRAINING_BACKEND_PORT", ] diff --git a/maro/rl/distributed/abs_worker.py b/maro/rl/distributed/abs_worker.py index 1f191034c..7da7e9435 
100644 --- a/maro/rl/distributed/abs_worker.py +++ b/maro/rl/distributed/abs_worker.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. from abc import abstractmethod +from typing import Union import zmq from tornado.ioloop import IOLoop @@ -33,7 +34,7 @@ def __init__( super(AbsWorker, self).__init__() self._id = f"worker.{idx}" - self._logger = logger if logger else DummyLogger() + self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger() # ZMQ sockets and streams self._context = Context.instance() diff --git a/maro/rl/distributed/port_config.py b/maro/rl/distributed/port_config.py new file mode 100644 index 000000000..f0828c769 --- /dev/null +++ b/maro/rl/distributed/port_config.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +DEFAULT_ROLLOUT_PRODUCER_PORT = 20000 +DEFAULT_TRAINING_FRONTEND_PORT = 10000 +DEFAULT_TRAINING_BACKEND_PORT = 10001 diff --git a/maro/rl/exploration/scheduling.py b/maro/rl/exploration/scheduling.py index 1276171b1..3981729c9 100644 --- a/maro/rl/exploration/scheduling.py +++ b/maro/rl/exploration/scheduling.py @@ -98,14 +98,15 @@ def __init__( start_ep: int = 1, initial_value: float = None, ) -> None: + super().__init__(exploration_params, param_name, initial_value=initial_value) + # validate splits - splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)] + splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)] splits.sort() for (ep, _), (ep2, _) in zip(splits, splits[1:]): if ep == ep2: raise ValueError("The zeroth element of split points must be unique") - super().__init__(exploration_params, param_name, initial_value=initial_value) self.final_value = final_value self._splits = splits self._ep = start_ep diff --git a/maro/rl/model/abs_net.py b/maro/rl/model/abs_net.py index 499eaa1d8..a559d1124 100644 --- a/maro/rl/model/abs_net.py +++ b/maro/rl/model/abs_net.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABCMeta -from typing import Any, Dict, Optional +from typing import Any, Dict import torch.nn from torch.optim import Optimizer @@ -18,7 +18,11 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta): def __init__(self) -> None: super(AbsNet, self).__init__() - self._optim: Optional[Optimizer] = None + @property + def optim(self) -> Optimizer: + optim = getattr(self, "_optim", None) + assert isinstance(optim, Optimizer), "Each AbsNet must have an optimizer" + return optim def step(self, loss: torch.Tensor) -> None: """Run a training step to update the net's parameters according to the given loss. @@ -26,9 +30,9 @@ def step(self, loss: torch.Tensor) -> None: Args: loss (torch.tensor): Loss used to update the model. """ - self._optim.zero_grad() + self.optim.zero_grad() loss.backward() - self._optim.step() + self.optim.step() def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: """Get the gradients with respect to all parameters according to the given loss. @@ -39,7 +43,7 @@ def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: Returns: Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters. 
""" - self._optim.zero_grad() + self.optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} @@ -51,7 +55,7 @@ def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: """ for name, param in self.named_parameters(): param.grad = grad[name] - self._optim.step() + self.optim.step() def _forward_unimplemented(self, *input: Any) -> None: pass @@ -64,7 +68,7 @@ def get_state(self) -> dict: """ return { "network": self.state_dict(), - "optim": self._optim.state_dict(), + "optim": self.optim.state_dict(), } def set_state(self, net_state: dict) -> None: @@ -74,7 +78,7 @@ def set_state(self, net_state: dict) -> None: net_state (dict): A dict that contains the net's state. """ self.load_state_dict(net_state["network"]) - self._optim.load_state_dict(net_state["optim"]) + self.optim.load_state_dict(net_state["optim"]) def soft_update(self, other_model: AbsNet, tau: float) -> None: """Soft update the net's parameters according to another net, i.e., diff --git a/maro/rl/model/fc_block.py b/maro/rl/model/fc_block.py index a323dedb0..31108765d 100644 --- a/maro/rl/model/fc_block.py +++ b/maro/rl/model/fc_block.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from collections import OrderedDict -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Tuple, Type import torch import torch.nn as nn @@ -46,7 +46,7 @@ def __init__( skip_connection: bool = False, dropout_p: float = None, gradient_threshold: float = None, - name: str = None, + name: str = "NONAME", ) -> None: super(FullyConnected, self).__init__() self._input_dim = input_dim @@ -101,12 +101,12 @@ def input_dim(self) -> int: def output_dim(self) -> int: return self._output_dim - def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module: + def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module: """Build a basic layer. BN -> Linear -> Activation -> Dropout """ - components = [] + components: List[Tuple[str, nn.Module]] = [] if self._batch_norm: components.append(("batch_norm", nn.BatchNorm1d(input_dim))) components.append(("linear", nn.Linear(input_dim, output_dim))) diff --git a/maro/rl/policy/abs_policy.py b/maro/rl/policy/abs_policy.py index c57c0db51..14b0bb3a9 100644 --- a/maro/rl/policy/abs_policy.py +++ b/maro/rl/policy/abs_policy.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -27,14 +27,14 @@ def __init__(self, name: str, trainable: bool) -> None: self._trainable = trainable @abstractmethod - def get_actions(self, states: object) -> object: + def get_actions(self, states: Union[list, np.ndarray]) -> Any: """Get actions according to states. Args: - states (object): States. + states (Union[list, np.ndarray]): States. Returns: - actions (object): Actions. + actions (Any): Actions. 
""" raise NotImplementedError @@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy): def __init__(self) -> None: super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False) - def get_actions(self, states: object) -> None: + def get_actions(self, states: Union[list, np.ndarray]) -> None: return None def explore(self) -> None: @@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta): def __init__(self, name: str) -> None: super(RuleBasedPolicy, self).__init__(name=name, trainable=False) - def get_actions(self, states: List[object]) -> List[object]: + def get_actions(self, states: list) -> list: return self._rule(states) @abstractmethod - def _rule(self, states: List[object]) -> List[object]: + def _rule(self, states: list) -> list: raise NotImplementedError def explore(self) -> None: @@ -304,7 +304,7 @@ def unfreeze(self) -> None: raise NotImplementedError @abstractmethod - def get_state(self) -> object: + def get_state(self) -> dict: """Get the state of the policy.""" raise NotImplementedError diff --git a/maro/rl/policy/continuous_rl_policy.py b/maro/rl/policy/continuous_rl_policy.py index e93cc982b..33ed3e55d 100644 --- a/maro/rl/policy/continuous_rl_policy.py +++ b/maro/rl/policy/continuous_rl_policy.py @@ -62,12 +62,10 @@ def __init__( ) self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range) - assert self._lbounds is not None and self._ubounds is not None - self._policy_net = policy_net @property - def action_bounds(self) -> Tuple[List[float], List[float]]: + def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]: return self._lbounds, self._ubounds @property @@ -118,7 +116,7 @@ def eval(self) -> None: def train(self) -> None: self._policy_net.train() - def get_state(self) -> object: + def get_state(self) -> dict: return self._policy_net.get_state() def set_state(self, policy_state: dict) -> None: diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index a332908dc..567e9d054 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -85,9 +85,11 @@ def __init__( self._exploration_func = exploration_strategy[0] self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing - self._exploration_schedulers = [ - opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options - ] + self._exploration_schedulers = ( + [opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options] + if exploration_scheduling_options is not None + else [] + ) self._call_cnt = 0 self._warmup = warmup @@ -219,7 +221,7 @@ def eval(self) -> None: def train(self) -> None: self._q_net.train() - def get_state(self) -> object: + def get_state(self) -> dict: return self._q_net.get_state() def set_state(self, policy_state: dict) -> None: diff --git a/maro/rl/rl_component/rl_component_bundle.py b/maro/rl/rl_component/rl_component_bundle.py index 772e4afc4..f85fe286b 100644 --- a/maro/rl/rl_component/rl_component_bundle.py +++ b/maro/rl/rl_component/rl_component_bundle.py @@ -1,194 +1,103 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from abc import abstractmethod -from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List -from maro.rl.policy import AbsPolicy +from maro.rl.policy import AbsPolicy, RLPolicy from maro.rl.rollout import AbsEnvSampler from maro.rl.training import AbsTrainer -from maro.simulator import Env -class RLComponentBundle(object): +class RLComponentBundle: """Bundle of all necessary components to run a RL job in MARO. - Users should create their own subclass of `RLComponentBundle` and implement following methods: - - get_env_config() - - get_test_env_config() - - get_env_sampler() - - get_agent2policy() - - get_policy_creator() - - get_trainer_creator() - - Following methods could be overwritten when necessary: - - get_device_mapping() - - Please refer to the doc string of each method for detailed explanations. + env_sampler (AbsEnvSampler): Environment sampler of the scenario. + agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job. For example: + {agent1: policy1, agent2: policy1, agent3: policy2}. + policies (List[AbsPolicy]): Policies. + trainers (List[AbsTrainer]): Trainers. + device_mapping (Dict[str, str], default=None): Device mapping that identifying which device to put each policy. + If None, there will be no explicit device assignment. + policy_trainer_mapping (Dict[str, str], default=None): Policy-trainer mapping which identifying which trainer to + train each policy. If None, then a policy's trainer's name is the first segment of the policy's name, + seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies that provided in + policy-trainer mapping are considered as trainable polices. Policies that not provided in policy-trainer + mapping will not be trained. """ - def __init__(self) -> None: - super(RLComponentBundle, self).__init__() - - self.trainer_creator: Optional[Dict[str, Callable[[], AbsTrainer]]] = None - - self.agent2policy: Optional[Dict[Any, str]] = None - self.trainable_agent2policy: Optional[Dict[Any, str]] = None - self.policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None - self.policy_names: Optional[List[str]] = None - self.trainable_policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None - self.trainable_policy_names: Optional[List[str]] = None - - self.device_mapping: Optional[Dict[str, str]] = None - self.policy_trainer_mapping: Optional[Dict[str, str]] = None - - self._policy_cache: Optional[Dict[str, AbsPolicy]] = None - - # Will be created when `env_sampler()` is first called - self._env_sampler: Optional[AbsEnvSampler] = None - - self._complete_resources() - - ######################################################################################## - # Users MUST implement the following methods # - ######################################################################################## - @abstractmethod - def get_env_config(self) -> dict: - """Return the environment configuration to build the MARO Env for training. - - Returns: - Environment configuration. - """ - raise NotImplementedError - - @abstractmethod - def get_test_env_config(self) -> Optional[dict]: - """Return the environment configuration to build the MARO Env for testing. If returns `None`, the training - environment will be reused as testing environment. - - Returns: - Environment configuration or `None`. - """ - raise NotImplementedError - - @abstractmethod - def get_env_sampler(self) -> AbsEnvSampler: - """Return the environment sampler of the scenario. 
- - Returns: - The environment sampler of the scenario. - """ - raise NotImplementedError - - @abstractmethod - def get_agent2policy(self) -> Dict[Any, str]: - """Return agent name to policy name mapping of the RL job. This mapping identifies which policy should - the agents use. For example: {agent1: policy1, agent2: policy1, agent3: policy2}. - - Returns: - Agent name to policy name mapping. - """ - raise NotImplementedError - - @abstractmethod - def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: - """Return policy creator. Policy creator is a dictionary that contains a group of functions that generate - policy instances. The key of this dictionary is the policy name, and the value is the function that generate - the corresponding policy instance. Note that the creation function should not take any parameters. - """ - raise NotImplementedError - - @abstractmethod - def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: - """Return trainer creator. Trainer creator is similar to policy creator, but is used to creator trainers.""" - raise NotImplementedError - - ######################################################################################## - # Users could overwrite the following methods # - ######################################################################################## - def get_device_mapping(self) -> Dict[str, str]: - """Return the device mapping that identifying which device to put each policy. - - If user does not overwrite this method, then all policies will be put on CPU by default. - """ - return {policy_name: "cpu" for policy_name in self.get_policy_creator()} - - def get_policy_trainer_mapping(self) -> Dict[str, str]: - """Return the policy-trainer mapping which identifying which trainer to train each policy. - - If user does not overwrite this method, then a policy's trainer's name is the first segment of the policy's - name, seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1". - - Only policies that provided in policy-trainer mapping are considered as trainable polices. Policies that - not provided in policy-trainer mapping will not be trained since we do not assign a trainer to it. - """ - return {policy_name: policy_name.split(".")[0] for policy_name in self.policy_creator} - - ######################################################################################## - # Methods invisible to users # - ######################################################################################## - @property - def env_sampler(self) -> AbsEnvSampler: - if self._env_sampler is None: - self._env_sampler = self.get_env_sampler() - self._env_sampler.build(self) - return self._env_sampler - - def _complete_resources(self) -> None: - """Generate all attributes by calling user-defined logics. 
Do necessary checking and transformations.""" - env_config = self.get_env_config() - test_env_config = self.get_test_env_config() - self.env = Env(**env_config) - self.test_env = self.env if test_env_config is None else Env(**test_env_config) - - self.trainer_creator = self.get_trainer_creator() - self.device_mapping = self.get_device_mapping() - - self.policy_creator = self.get_policy_creator() - self.agent2policy = self.get_agent2policy() - - self.policy_trainer_mapping = self.get_policy_trainer_mapping() - - required_policies = set(self.agent2policy.values()) - self.policy_creator = {name: self.policy_creator[name] for name in required_policies} + def __init__( + self, + env_sampler: AbsEnvSampler, + agent2policy: Dict[Any, str], + policies: List[AbsPolicy], + trainers: List[AbsTrainer], + device_mapping: Dict[str, str] = None, + policy_trainer_mapping: Dict[str, str] = None, + ) -> None: + self.env_sampler = env_sampler + self.agent2policy = agent2policy + self.policies = policies + self.trainers = trainers + + policy_set = set([policy.name for policy in self.policies]) + not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set] + if len(not_found) > 0: + raise ValueError(f"The following policies are required but cannot be found: [{', '.join(not_found)}]") + + # Remove unused policies + kept_policies = [] + for policy in self.policies: + if policy.name not in self.agent2policy.values(): + raise Warning(f"Policy {policy.name} is removed since it is not used by any agent.") + else: + kept_policies.append(policy) + self.policies = kept_policies + policy_set = set([policy.name for policy in self.policies]) + + self.device_mapping = ( + {k: v for k, v in device_mapping.items() if k in policy_set} if device_mapping is not None else {} + ) + self.policy_trainer_mapping = ( + policy_trainer_mapping + if policy_trainer_mapping is not None + else {policy_name: policy_name.split(".")[0] for policy_name in policy_set} + ) + + # Check missing trainers self.policy_trainer_mapping = { - name: self.policy_trainer_mapping[name] for name in required_policies if name in self.policy_trainer_mapping + policy_name: trainer_name + for policy_name, trainer_name in self.policy_trainer_mapping.items() + if policy_name in policy_set } - self.policy_names = list(required_policies) - assert len(required_policies) == len(self.policy_creator) # Should have same size after filter + trainer_set = set([trainer.name for trainer in self.trainers]) + not_found = [ + trainer_name for trainer_name in self.policy_trainer_mapping.values() if trainer_name not in trainer_set + ] + if len(not_found) > 0: + raise ValueError(f"The following trainers are required but cannot be found: [{', '.join(not_found)}]") + + # Remove unused trainers + kept_trainers = [] + for trainer in self.trainers: + if trainer.name not in self.policy_trainer_mapping.values(): + raise Warning(f"Trainer {trainer.name} if removed since no policy is trained by it.") + else: + kept_trainers.append(trainer) + self.trainers = kept_trainers - required_trainers = set(self.policy_trainer_mapping.values()) - self.trainer_creator = {name: self.trainer_creator[name] for name in required_trainers} - assert len(required_trainers) == len(self.trainer_creator) # Should have same size after filter - - self.trainable_policy_names = list(self.policy_trainer_mapping.keys()) - self.trainable_policy_creator = { - policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names - } - 
self.trainable_agent2policy = { + @property + def trainable_agent2policy(self) -> Dict[Any, str]: + return { agent_name: policy_name for agent_name, policy_name in self.agent2policy.items() - if policy_name in self.trainable_policy_names + if policy_name in self.policy_trainer_mapping } - def pre_create_policy_instances(self) -> None: - """Pre-create policy instances, and return the pre-created policy instances when the external callers - want to create new policies. This will ensure that each policy will have at most one reusable duplicate. - Under specific scenarios (for example, simple training & rollout), this will reduce unnecessary overheads. - """ - old_policy_creator = self.policy_creator - self._policy_cache: Dict[str, AbsPolicy] = {} - for policy_name in self.policy_names: - self._policy_cache[policy_name] = old_policy_creator[policy_name]() - - def _get_policy_instance(policy_name: str) -> AbsPolicy: - return self._policy_cache[policy_name] - - self.policy_creator = { - policy_name: partial(_get_policy_instance, policy_name) for policy_name in self.policy_names - } - - self.trainable_policy_creator = { - policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names - } + @property + def trainable_policies(self) -> List[RLPolicy]: + policies = [] + for policy in self.policies: + if policy.name in self.policy_trainer_mapping: + assert isinstance(policy, RLPolicy) + policies.append(policy) + return policies diff --git a/maro/rl/rollout/batch_env_sampler.py b/maro/rl/rollout/batch_env_sampler.py index 0a184b28b..a3504a156 100644 --- a/maro/rl/rollout/batch_env_sampler.py +++ b/maro/rl/rollout/batch_env_sampler.py @@ -4,12 +4,13 @@ import os import time from itertools import chain -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch import zmq from zmq import Context, Poller +from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes from maro.rl.utils.objects import FILE_SUFFIX from maro.utils import DummyLogger, LoggerV2 @@ -37,19 +38,19 @@ def __init__(self, port: int = 20000, logger: LoggerV2 = None) -> None: self._poller = Poller() self._poller.register(self._task_endpoint, zmq.POLLIN) - self._workers = set() - self._logger = logger + self._workers: set = set() + self._logger: Union[DummyLogger, LoggerV2] = logger if logger is not None else DummyLogger() def _wait_for_workers_ready(self, k: int) -> None: while len(self._workers) < k: self._workers.add(self._task_endpoint.recv_multipart()[0]) - def _recv_result_for_target_index(self, index: int) -> object: + def _recv_result_for_target_index(self, index: int) -> Any: rep = bytes_to_pyobj(self._task_endpoint.recv_multipart()[-1]) assert isinstance(rep, dict) return rep["result"] if rep["index"] == index else None - def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: int = None) -> List[dict]: + def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: float = None) -> List[dict]: """Send a task request to a set of remote workers and collect the results. 
Args: @@ -70,7 +71,7 @@ def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_fa min_replies = parallelism start_time = time.time() - results = [] + results: list = [] for worker_id in list(self._workers)[:parallelism]: self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)]) self._logger.debug(f"Sent {parallelism} roll-out requests...") @@ -81,7 +82,7 @@ def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_fa results.append(result) if grace_factor is not None: - countdown = int((time.time() - start_time) * grace_factor) * 1000 # milliseconds + countdown = int((time.time() - start_time) * grace_factor) * 1000.0 # milliseconds self._logger.debug(f"allowing {countdown / 1000} seconds for remaining results") while len(results) < parallelism and countdown > 0: start = time.time() @@ -125,15 +126,18 @@ class BatchEnvSampler: def __init__( self, sampling_parallelism: int, - port: int = 20000, + port: int = None, min_env_samples: int = None, grace_factor: float = None, eval_parallelism: int = None, logger: LoggerV2 = None, ) -> None: super(BatchEnvSampler, self).__init__() - self._logger = logger if logger else DummyLogger() - self._controller = ParallelTaskController(port=port, logger=logger) + self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger() + self._controller = ParallelTaskController( + port=port if port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, + logger=logger, + ) self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism @@ -143,11 +147,15 @@ def __init__( self._ep = 0 self._end_of_episode = True - def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Optional[int] = None) -> dict: + def sample( + self, + policy_state: Optional[Dict[str, Dict[str, Any]]] = None, + num_steps: Optional[int] = None, + ) -> dict: """Collect experiences from a set of remote roll-out workers. Args: - policy_state (Dict[str, object]): Policy state dict. If it is not None, then we need to update all + policy_state (Dict[str, Any]): Policy state dict. If it is not None, then we need to update all policies according to the latest policy states, then start the experience collection. num_steps (Optional[int], default=None): Number of environment steps to collect experiences for. 
If it is None, interactions with the (remote) environments will continue until the terminal state is @@ -181,7 +189,7 @@ def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Op "info": [res["info"][0] for res in results], } - def eval(self, policy_state: Dict[str, object] = None) -> dict: + def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict: req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test results = self._controller.collect(req, self._eval_parallelism) return { @@ -209,3 +217,11 @@ def load_policy_state(self, path: str) -> List[str]: def exit(self) -> None: self._controller.exit() + + def post_collect(self, info_list: list, ep: int) -> None: + req = {"type": "post_collect", "info_list": info_list, "index": ep} + self._controller.collect(req, 1) + + def post_evaluate(self, info_list: list, ep: int) -> None: + req = {"type": "post_evaluate", "info_list": info_list, "index": ep} + self._controller.collect(req, 1) diff --git a/maro/rl/rollout/env_sampler.py b/maro/rl/rollout/env_sampler.py index d05cce494..81e6ccdad 100644 --- a/maro/rl/rollout/env_sampler.py +++ b/maro/rl/rollout/env_sampler.py @@ -5,7 +5,6 @@ import collections import os -import typing from abc import ABCMeta, abstractmethod from copy import deepcopy from dataclasses import dataclass @@ -18,9 +17,6 @@ from maro.rl.utils.objects import FILE_SUFFIX from maro.simulator import Env -if typing.TYPE_CHECKING: - from maro.rl.rl_component.rl_component_bundle import RLComponentBundle - class AbsAgentWrapper(object, metaclass=ABCMeta): """Agent wrapper. Used to manager agents & policies during experience collection. @@ -51,16 +47,16 @@ def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None: def choose_actions( self, - state_by_agent: Dict[Any, Union[np.ndarray, List[object]]], - ) -> Dict[Any, Union[np.ndarray, List[object]]]: + state_by_agent: Dict[Any, Union[np.ndarray, list]], + ) -> Dict[Any, Union[np.ndarray, list]]: """Choose action according to the given (observable) states of all agents. Args: - state_by_agent (Dict[Any, Union[np.ndarray, List[object]]]): Dictionary containing each agent's states. + state_by_agent (Dict[Any, Union[np.ndarray, list]]): Dictionary containing each agent's states. If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects. Returns: - actions (Dict[Any, Union[np.ndarray, List[object]]]): Dict that contains the action for all agents. + actions (Dict[Any, Union[np.ndarray, list]]): Dict that contains the action for all agents. If the policy is a `RLPolicy`, its action is a Numpy array. Otherwise, its action is a list of objects. 
""" self.switch_to_eval_mode() @@ -71,8 +67,8 @@ def choose_actions( @abstractmethod def _choose_actions_impl( self, - state_by_agent: Dict[Any, Union[np.ndarray, List[object]]], - ) -> Dict[Any, Union[np.ndarray, List[object]]]: + state_by_agent: Dict[Any, Union[np.ndarray, list]], + ) -> Dict[Any, Union[np.ndarray, list]]: """Implementation of `choose_actions`.""" raise NotImplementedError @@ -95,15 +91,15 @@ def switch_to_eval_mode(self) -> None: class SimpleAgentWrapper(AbsAgentWrapper): def __init__( self, - policy_dict: Dict[str, RLPolicy], # {policy_name: RLPolicy} + policy_dict: Dict[str, AbsPolicy], # {policy_name: AbsPolicy} agent2policy: Dict[Any, str], # {agent_name: policy_name} ) -> None: super(SimpleAgentWrapper, self).__init__(policy_dict=policy_dict, agent2policy=agent2policy) def _choose_actions_impl( self, - state_by_agent: Dict[Any, Union[np.ndarray, List[object]]], - ) -> Dict[Any, Union[np.ndarray, List[object]]]: + state_by_agent: Dict[Any, Union[np.ndarray, list]], + ) -> Dict[Any, Union[np.ndarray, list]]: # Aggregate states by policy states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray} agents_by_policy = collections.defaultdict(list) # {str: list of str} @@ -112,15 +108,15 @@ def _choose_actions_impl( states_by_policy[policy_name].append(state) agents_by_policy[policy_name].append(agent_name) - action_dict = {} + action_dict: dict = {} for policy_name in agents_by_policy: policy = self._policy_dict[policy_name] if isinstance(policy, RLPolicy): states = np.vstack(states_by_policy[policy_name]) # np.ndarray else: - states = states_by_policy[policy_name] # List[object] - actions = policy.get_actions(states) # np.ndarray or List[object] + states = states_by_policy[policy_name] # list + actions: Union[np.ndarray, list] = policy.get_actions(states) # np.ndarray or list action_dict.update(zip(agents_by_policy[policy_name], actions)) return action_dict @@ -188,7 +184,7 @@ def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str, Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this dict is the trainer name. """ - ret = collections.defaultdict( + ret: Dict[str, ExpElement] = collections.defaultdict( lambda: ExpElement( tick=self.tick, state=self.state, @@ -213,7 +209,7 @@ def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str, @dataclass class CacheElement(ExpElement): - event: object + event: Any env_action_dict: Dict[Any, np.ndarray] def make_exp_element(self) -> ExpElement: @@ -238,6 +234,9 @@ class AbsEnvSampler(object, metaclass=ABCMeta): Args: learn_env (Env): Environment used for training. test_env (Env): Environment used for testing. + policies (List[AbsPolicy]): List of policies. + agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job. + trainable_policies (List[str]): Name of trainable policies. agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type. reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward for the action taken for that event. If it is None, calculate reward immediately after `step()`. 
@@ -247,6 +246,9 @@ def __init__( self, learn_env: Env, test_env: Env, + policies: List[AbsPolicy], + agent2policy: Dict[Any, str], + trainable_policies: List[str] = None, agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper, reward_eval_delay: int = None, ) -> None: @@ -255,7 +257,7 @@ def __init__( self._agent_wrapper_cls = agent_wrapper_cls - self._event = None + self._event: Optional[list] = None self._end_of_episode = True self._state: Optional[np.ndarray] = None self._agent_state_dict: Dict[Any, np.ndarray] = {} @@ -264,31 +266,23 @@ def __init__( self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache self._reward_eval_delay = reward_eval_delay - self._info = {} + self._info: dict = {} assert self._reward_eval_delay is None or self._reward_eval_delay >= 0 - def build( - self, - rl_component_bundle: RLComponentBundle, - ) -> None: - """ - Args: - rl_component_bundle (RLComponentBundle): The RL component bundle of the job. - """ + # self._env: Optional[Env] = None - - self._policy_dict = { - policy_name: rl_component_bundle.policy_creator[policy_name]() - for policy_name in rl_component_bundle.policy_names - } - + self._policy_dict: Dict[str, AbsPolicy] = {policy.name: policy for policy in policies} self._rl_policy_dict: Dict[str, RLPolicy] = { - name: policy for name, policy in self._policy_dict.items() if isinstance(policy, RLPolicy) + policy.name: policy for policy in policies if isinstance(policy, RLPolicy) } - self._agent2policy = rl_component_bundle.agent2policy + self._agent2policy = agent2policy self._agent_wrapper = self._agent_wrapper_cls(self._policy_dict, self._agent2policy) - self._trainable_policies = set(rl_component_bundle.trainable_policy_names) + + if trainable_policies is not None: + self._trainable_policies = trainable_policies + else: + self._trainable_policies = list(self._policy_dict.keys()) # Default: all policies are trainable self._trainable_agents = { agent_id for agent_id, policy_name in self._agent2policy.items() if policy_name in self._trainable_policies } @@ -297,23 +291,31 @@ def build( [policy_name in self._rl_policy_dict for policy_name in self._trainable_policies], ), "All trainable policies must be RL policies!" + @property + def env(self) -> Env: + assert self._env is not None + return self._env + + def _switch_env(self, env: Env) -> None: + self._env = env + def assign_policy_to_device(self, policy_name: str, device: torch.device) -> None: self._rl_policy_dict[policy_name].to_device(device) def _get_global_and_agent_state( self, - event: object, + event: Any, tick: int = None, - ) -> Tuple[Optional[object], Dict[Any, Union[np.ndarray, List[object]]]]: + ) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]: """Get the global and individual agents' states. Args: - event (object): Event. + event (Any): Event. tick (int, default=None): Current tick. Returns: - Global state (Optional[object]) - Dict of agent states (Dict[Any, Union[np.ndarray, List[object]]]). If the policy is a `RLPolicy`, + Global state (Optional[Any]) + Dict of agent states (Dict[Any, Union[np.ndarray, list]]). If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects. 
""" global_state, agent_state_dict = self._get_global_and_agent_state_impl(event, tick) @@ -327,23 +329,23 @@ def _get_global_and_agent_state( @abstractmethod def _get_global_and_agent_state_impl( self, - event: object, + event: Any, tick: int = None, - ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]: + ) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]: raise NotImplementedError @abstractmethod def _translate_to_env_action( self, - action_dict: Dict[Any, Union[np.ndarray, List[object]]], - event: object, - ) -> Dict[Any, object]: + action_dict: Dict[Any, Union[np.ndarray, list]], + event: Any, + ) -> dict: """Translate model-generated actions into an object that can be executed by the env. Args: - action_dict (Dict[Any, Union[np.ndarray, List[object]]]): Action for all agents. If the policy is a + action_dict (Dict[Any, Union[np.ndarray, list]]): Action for all agents. If the policy is a `RLPolicy`, its (input) action is a Numpy array. Otherwise, its (input) action is a list of objects. - event (object): Decision event. + event (Any): Decision event. Returns: A dict that contains env actions for all agents. @@ -351,12 +353,12 @@ def _translate_to_env_action( raise NotImplementedError @abstractmethod - def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: int) -> Dict[Any, float]: + def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]: """Get rewards according to the env actions. Args: - env_action_dict (Dict[Any, object]): Dict that contains env actions for all agents. - event (object): Decision event. + env_action_dict (dict): Dict that contains env actions for all agents. + event (Any): Decision event. tick (int): Current tick. Returns: @@ -365,7 +367,7 @@ def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: i raise NotImplementedError def _step(self, actions: Optional[list]) -> None: - _, self._event, self._end_of_episode = self._env.step(actions) + _, self._event, self._end_of_episode = self.env.step(actions) self._state, self._agent_state_dict = ( (None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event) ) @@ -403,7 +405,7 @@ def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None: self._agent_last_index[agent_name] = cur_index def _reset(self) -> None: - self._env.reset() + self.env.reset() self._info.clear() self._trans_cache.clear() self._agent_last_index.clear() @@ -412,7 +414,11 @@ def _reset(self) -> None: def _select_trainable_agents(self, original_dict: dict) -> dict: return {k: v for k, v in original_dict.items() if k in self._trainable_agents} - def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Optional[int] = None) -> dict: + def sample( + self, + policy_state: Optional[Dict[str, Dict[str, Any]]] = None, + num_steps: Optional[int] = None, + ) -> dict: """Sample experiences. Args: @@ -425,7 +431,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti A dict that contains the collected experiences and additional information. 
""" # Init the env - self._env = self._learn_env + self._switch_env(self._learn_env) if self._end_of_episode: self._reset() @@ -443,7 +449,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti # Store experiences in the cache cache_element = CacheElement( - tick=self._env.tick, + tick=self.env.tick, event=self._event, state=self._state, agent_state_dict=self._select_trainable_agents(self._agent_state_dict), @@ -466,7 +472,7 @@ def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Opti steps_to_go -= 1 self._append_cache_element(None) - tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) experiences: List[ExpElement] = [] while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: cache_element = self._trans_cache.pop(0) @@ -508,8 +514,8 @@ def load_policy_state(self, path: str) -> List[str]: return loaded - def eval(self, policy_state: Dict[str, dict] = None) -> dict: - self._env = self._test_env + def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict: + self._switch_env(self._test_env) self._reset() if policy_state is not None: self.set_policy_state(policy_state) @@ -521,7 +527,7 @@ def eval(self, policy_state: Dict[str, dict] = None) -> dict: # Store experiences in the cache cache_element = CacheElement( - tick=self._env.tick, + tick=self.env.tick, event=self._event, state=self._state, agent_state_dict=self._select_trainable_agents(self._agent_state_dict), @@ -544,7 +550,7 @@ def eval(self, policy_state: Dict[str, dict] = None) -> dict: self._append_cache_element(cache_element) self._append_cache_element(None) - tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: cache_element = self._trans_cache.pop(0) if self._reward_eval_delay is not None: diff --git a/maro/rl/rollout/worker.py b/maro/rl/rollout/worker.py index 1067eb5c3..b8301ee38 100644 --- a/maro/rl/rollout/worker.py +++ b/maro/rl/rollout/worker.py @@ -5,7 +5,7 @@ import typing -from maro.rl.distributed import AbsWorker +from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT, AbsWorker from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes from maro.utils import LoggerV2 @@ -19,7 +19,7 @@ class RolloutWorker(AbsWorker): Args: idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}", so that the parallel roll-out controller can keep track of its connection status. - rl_component_bundle (RLComponentBundle): The RL component bundle of the job. + rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow. producer_host (str): IP address of the parallel task controller host to connect to. producer_port (int, default=20000): Port of the parallel task controller host to connect to. logger (LoggerV2, default=None): The logger of the workflow. 
@@ -30,13 +30,13 @@ def __init__( idx: int, rl_component_bundle: RLComponentBundle, producer_host: str, - producer_port: int = 20000, + producer_port: int = None, logger: LoggerV2 = None, ) -> None: super(RolloutWorker, self).__init__( idx=idx, producer_host=producer_host, - producer_port=producer_port, + producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, logger=logger, ) self._env_sampler = rl_component_bundle.env_sampler @@ -53,13 +53,20 @@ def _compute(self, msg: list) -> None: else: req = bytes_to_pyobj(msg[-1]) assert isinstance(req, dict) - assert req["type"] in {"sample", "eval", "set_policy_state"} - if req["type"] == "sample": - result = self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"]) - elif req["type"] == "eval": - result = self._env_sampler.eval(policy_state=req["policy_state"]) - else: - self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"]) - result = True + assert req["type"] in {"sample", "eval", "set_policy_state", "post_collect", "post_evaluate"} - self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]})) + if req["type"] in ("sample", "eval"): + result = ( + self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"]) + if req["type"] == "sample" + else self._env_sampler.eval(policy_state=req["policy_state"]) + ) + self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]})) + else: + if req["type"] == "set_policy_state": + self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"]) + elif req["type"] == "post_collect": + self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"]) + else: + self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"]) + self._stream.send(pyobj_to_bytes({"result": True, "index": req["index"]})) diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py index 3f2d01a4c..a77296f98 100644 --- a/maro/rl/training/__init__.py +++ b/maro/rl/training/__init__.py @@ -4,7 +4,7 @@ from .proxy import TrainingProxy from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory from .train_ops import AbsTrainOps, RemoteOps, remote -from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams +from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer from .training_manager import TrainingManager from .worker import TrainOpsWorker @@ -18,9 +18,9 @@ "RemoteOps", "remote", "AbsTrainer", + "BaseTrainerParams", "MultiAgentTrainer", "SingleAgentTrainer", - "TrainerParams", "TrainingManager", "TrainOpsWorker", ] diff --git a/maro/rl/training/algorithms/ac.py b/maro/rl/training/algorithms/ac.py index 2f9d576e2..4486daee3 100644 --- a/maro/rl/training/algorithms/ac.py +++ b/maro/rl/training/algorithms/ac.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Dict from maro.rl.training.algorithms.base import ACBasedParams, ACBasedTrainer @@ -13,18 +12,8 @@ class ActorCriticParams(ACBasedParams): for detailed information. 
""" - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_v_critic_net_func": self.get_v_critic_net_func, - "reward_discount": self.reward_discount, - "critic_loss_cls": self.critic_loss_cls, - "lam": self.lam, - "min_logp": self.min_logp, - "is_discrete_action": self.is_discrete_action, - } - def __post_init__(self) -> None: - assert self.get_v_critic_net_func is not None + assert self.clip_ratio is None class ActorCriticTrainer(ACBasedTrainer): @@ -34,5 +23,20 @@ class ActorCriticTrainer(ACBasedTrainer): https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/vpg """ - def __init__(self, name: str, params: ActorCriticParams) -> None: - super(ActorCriticTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: ActorCriticParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(ActorCriticTrainer, self).__init__( + name, + params, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py index 69b23f573..3227437be 100644 --- a/maro/rl/training/algorithms/base/ac_ppo_base.py +++ b/maro/rl/training/algorithms/base/ac_ppo_base.py @@ -3,19 +3,19 @@ from abc import ABCMeta from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, cast import numpy as np import torch from maro.rl.model import VNet from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy -from maro.rl.training import AbsTrainOps, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, discount_cumsum, get_torch_device, ndarray_to_tensor @dataclass -class ACBasedParams(TrainerParams, metaclass=ABCMeta): +class ACBasedParams(BaseTrainerParams, metaclass=ABCMeta): """ Parameter bundle for Actor-Critic based algorithms (Actor-Critic & PPO) @@ -23,18 +23,16 @@ class ACBasedParams(TrainerParams, metaclass=ABCMeta): grad_iters (int, default=1): Number of iterations to calculate gradients. critic_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE. lam (float, default=0.9): Lambda value for generalized advantage estimation (TD-Lambda). - min_logp (float, default=None): Lower bound for clamping logP values during learning. + min_logp (float, default=float("-inf")): Lower bound for clamping logP values during learning. This is to prevent logP from becoming very large in magnitude and causing stability issues. - If it is None, it means no lower bound. - is_discrete_action (bool, default=True): Indicator of continuous or discrete action policy. 
""" - get_v_critic_net_func: Callable[[], VNet] = None + get_v_critic_net_func: Callable[[], VNet] grad_iters: int = 1 - critic_loss_cls: Callable = None + critic_loss_cls: Optional[Callable] = None lam: float = 0.9 - min_logp: Optional[float] = None - is_discrete_action: bool = True + min_logp: float = float("-inf") + clip_ratio: Optional[float] = None class ACBasedOps(AbsTrainOps): @@ -43,33 +41,26 @@ class ACBasedOps(AbsTrainOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - get_v_critic_net_func: Callable[[], VNet], - parallelism: int = 1, + policy: RLPolicy, + params: ACBasedParams, reward_discount: float = 0.9, - critic_loss_cls: Callable = None, - clip_ratio: float = None, - lam: float = 0.9, - min_logp: float = None, - is_discrete_action: bool = True, + parallelism: int = 1, ) -> None: super(ACBasedOps, self).__init__( name=name, - policy_creator=policy_creator, + policy=policy, parallelism=parallelism, ) - assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy) + assert isinstance(self._policy, (ContinuousRLPolicy, DiscretePolicyGradient)) self._reward_discount = reward_discount - self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() - self._clip_ratio = clip_ratio - self._lam = lam - self._min_logp = min_logp - self._v_critic_net = get_v_critic_net_func() - self._is_discrete_action = is_discrete_action - - self._device = None + self._critic_loss_func = params.critic_loss_cls() if params.critic_loss_cls is not None else torch.nn.MSELoss() + self._clip_ratio = params.clip_ratio + self._lam = params.lam + self._min_logp = params.min_logp + self._v_critic_net = params.get_v_critic_net_func() + self._is_discrete_action = isinstance(self._policy, DiscretePolicyGradient) def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: """Compute the critic loss of the batch. 
@@ -249,14 +240,32 @@ class ACBasedTrainer(SingleAgentTrainer): https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f """ - def __init__(self, name: str, params: ACBasedParams) -> None: - super(ACBasedTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: ACBasedParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(ACBasedTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, (ContinuousRLPolicy, DiscretePolicyGradient)) + self._policy = policy + def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(ACBasedOps, self.get_ops()) self._replay_memory = FIFOReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, ) @@ -266,10 +275,11 @@ def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatc def get_local_ops(self) -> AbsTrainOps: return ACBasedOps( - name=self._policy_name, - policy_creator=self._policy_creator, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def _get_batch(self) -> TransitionBatch: diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index 2c95ba7c4..79bd5b336 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -2,19 +2,19 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Callable, Dict +from typing import Callable, Dict, Optional, cast import torch from maro.rl.model import QNet from maro.rl.policy import ContinuousRLPolicy, RLPolicy -from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @dataclass -class DDPGParams(TrainerParams): +class DDPGParams(BaseTrainerParams): """ get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net. num_epochs (int, default=1): Number of training epochs per call to ``learn``. @@ -30,25 +30,14 @@ class DDPGParams(TrainerParams): min_num_to_trigger_training (int, default=0): Minimum number required to start training. """ - get_q_critic_net_func: Callable[[], QNet] = None + get_q_critic_net_func: Callable[[], QNet] num_epochs: int = 1 update_target_every: int = 5 - q_value_loss_cls: Callable = None + q_value_loss_cls: Optional[Callable] = None soft_update_coef: float = 1.0 random_overwrite: bool = False min_num_to_trigger_training: int = 0 - def __post_init__(self) -> None: - assert self.get_q_critic_net_func is not None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_q_critic_net_func": self.get_q_critic_net_func, - "reward_discount": self.reward_discount, - "q_value_loss_cls": self.q_value_loss_cls, - "soft_update_coef": self.soft_update_coef, - } - class DDPGOps(AbsTrainOps): """DDPG algorithm implementation. 
Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html""" @@ -56,31 +45,31 @@ class DDPGOps(AbsTrainOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - get_q_critic_net_func: Callable[[], QNet], - reward_discount: float, + policy: RLPolicy, + params: DDPGParams, + reward_discount: float = 0.9, parallelism: int = 1, - q_value_loss_cls: Callable = None, - soft_update_coef: float = 1.0, ) -> None: super(DDPGOps, self).__init__( name=name, - policy_creator=policy_creator, + policy=policy, parallelism=parallelism, ) assert isinstance(self._policy, ContinuousRLPolicy) - self._target_policy = clone(self._policy) + self._target_policy: ContinuousRLPolicy = clone(self._policy) self._target_policy.set_name(f"target_{self._policy.name}") self._target_policy.eval() - self._q_critic_net = get_q_critic_net_func() + self._q_critic_net = params.get_q_critic_net_func() self._target_q_critic_net: QNet = clone(self._q_critic_net) self._target_q_critic_net.eval() self._reward_discount = reward_discount - self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() - self._soft_update_coef = soft_update_coef + self._q_value_loss_func = ( + params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss() + ) + self._soft_update_coef = params.soft_update_coef def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: """Compute the critic loss of the batch. @@ -207,7 +196,7 @@ def soft_update_target(self) -> None: self._target_policy.soft_update(self._policy, self._soft_update_coef) self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) - def to_device(self, device: str) -> None: + def to_device(self, device: str = None) -> None: self._device = get_torch_device(device=device) self._policy.to_device(self._device) self._target_policy.to_device(self._device) @@ -223,30 +212,49 @@ class DDPGTrainer(SingleAgentTrainer): https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg """ - def __init__(self, name: str, params: DDPGParams) -> None: - super(DDPGTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: DDPGParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(DDPGTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params self._policy_version = self._target_policy_version = 0 self._memory_size = 0 def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(DDPGOps, self.get_ops()) self._replay_memory = RandomReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, random_overwrite=self._params.random_overwrite, ) + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, ContinuousRLPolicy) + self._policy = policy + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: return transition_batch def get_local_ops(self) -> AbsTrainOps: return DDPGOps( - name=self._policy_name, - policy_creator=self._policy_creator, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def 
_get_batch(self, batch_size: int = None) -> TransitionBatch: diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 9cdd46eeb..5a4f938ab 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -2,18 +2,18 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Callable, Dict +from typing import Dict, cast import torch from maro.rl.policy import RLPolicy, ValueBasedPolicy -from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @dataclass -class DQNParams(TrainerParams): +class DQNParams(BaseTrainerParams): """ num_epochs (int, default=1): Number of training epochs. update_target_every (int, default=5): Number of gradient steps between target model updates. @@ -33,42 +33,34 @@ class DQNParams(TrainerParams): double: bool = False random_overwrite: bool = False - def extract_ops_params(self) -> Dict[str, object]: - return { - "reward_discount": self.reward_discount, - "soft_update_coef": self.soft_update_coef, - "double": self.double, - } - class DQNOps(AbsTrainOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - parallelism: int = 1, + policy: RLPolicy, + params: DQNParams, reward_discount: float = 0.9, - soft_update_coef: float = 0.1, - double: bool = False, + parallelism: int = 1, ) -> None: super(DQNOps, self).__init__( name=name, - policy_creator=policy_creator, + policy=policy, parallelism=parallelism, ) assert isinstance(self._policy, ValueBasedPolicy) self._reward_discount = reward_discount - self._soft_update_coef = soft_update_coef - self._double = double + self._soft_update_coef = params.soft_update_coef + self._double = params.double self._loss_func = torch.nn.MSELoss() self._target_policy: ValueBasedPolicy = clone(self._policy) self._target_policy.set_name(f"target_{self._policy.name}") self._target_policy.eval() - def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]: + def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: """Compute the loss of the batch. Args: @@ -78,6 +70,8 @@ def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.T loss (torch.Tensor): The loss of the batch. """ assert isinstance(batch, TransitionBatch) + assert isinstance(self._policy, ValueBasedPolicy) + self._policy.train() states = ndarray_to_tensor(batch.states, device=self._device) next_states = ndarray_to_tensor(batch.next_states, device=self._device) @@ -100,7 +94,7 @@ def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.T return self._loss_func(q_values, target_q_values) @remote - def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]: + def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: """Compute the network's gradients of a batch. 
Args: @@ -141,7 +135,7 @@ def soft_update_target(self) -> None: """Soft update the target policy.""" self._target_policy.soft_update(self._policy, self._soft_update_coef) - def to_device(self, device: str) -> None: + def to_device(self, device: str = None) -> None: self._device = get_torch_device(device) self._policy.to_device(self._device) self._target_policy.to_device(self._device) @@ -153,29 +147,48 @@ class DQNTrainer(SingleAgentTrainer): See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details. """ - def __init__(self, name: str, params: DQNParams) -> None: - super(DQNTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: DQNParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(DQNTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params self._q_net_version = self._target_q_net_version = 0 def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(DQNOps, self.get_ops()) self._replay_memory = RandomReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, random_overwrite=self._params.random_overwrite, ) + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, ValueBasedPolicy) + self._policy = policy + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: return transition_batch def get_local_ops(self) -> AbsTrainOps: return DQNOps( - name=self._policy_name, - policy_creator=self._policy_creator, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def _get_batch(self, batch_size: int = None) -> TransitionBatch: diff --git a/maro/rl/training/algorithms/maddpg.py b/maro/rl/training/algorithms/maddpg.py index cf19e8627..edc63f39a 100644 --- a/maro/rl/training/algorithms/maddpg.py +++ b/maro/rl/training/algorithms/maddpg.py @@ -4,7 +4,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -12,14 +12,21 @@ from maro.rl.model import MultiQNet from maro.rl.policy import DiscretePolicyGradient, RLPolicy from maro.rl.rollout import ExpElement -from maro.rl.training import AbsTrainOps, MultiAgentTrainer, RandomMultiReplayMemory, RemoteOps, TrainerParams, remote +from maro.rl.training import ( + AbsTrainOps, + BaseTrainerParams, + MultiAgentTrainer, + RandomMultiReplayMemory, + RemoteOps, + remote, +) from maro.rl.utils import MultiTransitionBatch, get_torch_device, ndarray_to_tensor from maro.rl.utils.objects import FILE_SUFFIX from maro.utils import clone @dataclass -class DiscreteMADDPGParams(TrainerParams): +class DiscreteMADDPGParams(BaseTrainerParams): """ get_q_critic_net_func (Callable[[], MultiQNet]): Function to get multi Q critic net. num_epochs (int, default=10): Number of training epochs. @@ -30,44 +37,28 @@ class DiscreteMADDPGParams(TrainerParams): shared_critic (bool, default=False): Whether different policies use shared critic or individual policies. 
""" - get_q_critic_net_func: Callable[[], MultiQNet] = None + get_q_critic_net_func: Callable[[], MultiQNet] num_epoch: int = 10 update_target_every: int = 5 soft_update_coef: float = 0.5 - q_value_loss_cls: Callable = None + q_value_loss_cls: Optional[Callable] = None shared_critic: bool = False - def __post_init__(self) -> None: - assert self.get_q_critic_net_func is not None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_q_critic_net_func": self.get_q_critic_net_func, - "shared_critic": self.shared_critic, - "reward_discount": self.reward_discount, - "soft_update_coef": self.soft_update_coef, - "update_target_every": self.update_target_every, - "q_value_loss_func": self.q_value_loss_cls() if self.q_value_loss_cls is not None else torch.nn.MSELoss(), - } - class DiscreteMADDPGOps(AbsTrainOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - get_q_critic_net_func: Callable[[], MultiQNet], + policy: RLPolicy, + param: DiscreteMADDPGParams, + shared_critic: bool, policy_idx: int, parallelism: int = 1, - shared_critic: bool = False, reward_discount: float = 0.9, - soft_update_coef: float = 0.5, - update_target_every: int = 5, - q_value_loss_func: Callable = None, ) -> None: super(DiscreteMADDPGOps, self).__init__( name=name, - policy_creator=policy_creator, + policy=policy, parallelism=parallelism, ) @@ -75,23 +66,21 @@ def __init__( self._shared_critic = shared_critic # Actor - if self._policy_creator: + if self._policy: assert isinstance(self._policy, DiscretePolicyGradient) self._target_policy: DiscretePolicyGradient = clone(self._policy) self._target_policy.set_name(f"target_{self._policy.name}") self._target_policy.eval() # Critic - self._q_critic_net: MultiQNet = get_q_critic_net_func() + self._q_critic_net: MultiQNet = param.get_q_critic_net_func() self._target_q_critic_net: MultiQNet = clone(self._q_critic_net) self._target_q_critic_net.eval() self._reward_discount = reward_discount - self._q_value_loss_func = q_value_loss_func - self._update_target_every = update_target_every - self._soft_update_coef = soft_update_coef - - self._device = None + self._q_value_loss_func = param.q_value_loss_cls() if param.q_value_loss_cls is not None else torch.nn.MSELoss() + self._update_target_every = param.update_target_every + self._soft_update_coef = param.soft_update_coef def get_target_action(self, batch: MultiTransitionBatch) -> torch.Tensor: """Get the target policies' actions according to the batch. 
@@ -248,7 +237,7 @@ def update_actor_with_grad(self, grad_dict: dict) -> None: def soft_update_target(self) -> None: """Soft update the target policies and target critics.""" - if self._policy_creator: + if self._policy: self._target_policy.soft_update(self._policy, self._soft_update_coef) if not self._shared_critic: self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) @@ -264,13 +253,13 @@ def set_critic_state(self, ops_state_dict: dict) -> None: self._target_q_critic_net.set_state(ops_state_dict["target_critic"]) def get_actor_state(self) -> dict: - if self._policy_creator: + if self._policy: return {"policy": self._policy.get_state(), "target_policy": self._target_policy.get_state()} else: return {} def set_actor_state(self, ops_state_dict: dict) -> None: - if self._policy_creator: + if self._policy: self._policy.set_state(ops_state_dict["policy"]) self._target_policy.set_state(ops_state_dict["target_policy"]) @@ -280,9 +269,9 @@ def get_non_policy_state(self) -> dict: def set_non_policy_state(self, state: dict) -> None: self.set_critic_state(state) - def to_device(self, device: str) -> None: + def to_device(self, device: str = None) -> None: self._device = get_torch_device(device) - if self._policy_creator: + if self._policy: self._policy.to_device(self._device) self._target_policy.to_device(self._device) @@ -296,31 +285,51 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer): See https://arxiv.org/abs/1706.02275 for details. """ - def __init__(self, name: str, params: DiscreteMADDPGParams) -> None: - super(DiscreteMADDPGTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: DiscreteMADDPGParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(DiscreteMADDPGTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params - self._ops_params = self._params.extract_ops_params() + self._state_dim = params.get_q_critic_net_func().state_dim self._policy_version = self._target_policy_version = 0 self._shared_critic_ops_name = f"{self._name}.shared_critic" - self._actor_ops_list = [] - self._critic_ops = None - self._replay_memory = None - self._policy2agent = {} + self._actor_ops_list: List[DiscreteMADDPGOps] = [] + self._critic_ops: Optional[DiscreteMADDPGOps] = None + self._policy2agent: Dict[str, str] = {} + self._ops_dict: Dict[str, DiscreteMADDPGOps] = {} def build(self) -> None: - for policy_name in self._policy_creator: - self._ops_dict[policy_name] = self.get_ops(policy_name) + self._placeholder_policy = self._policy_dict[self._policy_names[0]] + + for policy in self._policy_dict.values(): + self._ops_dict[policy.name] = cast(DiscreteMADDPGOps, self.get_ops(policy.name)) self._actor_ops_list = list(self._ops_dict.values()) if self._params.shared_critic: - self._ops_dict[self._shared_critic_ops_name] = self.get_ops(self._shared_critic_ops_name) + assert self._critic_ops is not None + self._ops_dict[self._shared_critic_ops_name] = cast( + DiscreteMADDPGOps, + self.get_ops(self._shared_critic_ops_name), + ) self._critic_ops = self._ops_dict[self._shared_critic_ops_name] self._replay_memory = RandomMultiReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._state_dim, action_dims=[ops.policy_action_dim for ops in self._actor_ops_list], agent_states_dims=[ops.policy_state_dim for ops in 
self._actor_ops_list], @@ -342,7 +351,7 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: rewards: List[np.ndarray] = [] agent_states: List[np.ndarray] = [] next_agent_states: List[np.ndarray] = [] - for policy_name in self._policy_names: + for policy_name in self._policy_dict: agent_name = self._policy2agent[policy_name] actions.append(np.vstack([exp_element.action_dict[agent_name] for exp_element in exp_elements])) rewards.append(np.array([exp_element.reward_dict[agent_name] for exp_element in exp_elements])) @@ -374,23 +383,25 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: def get_local_ops(self, name: str) -> AbsTrainOps: if name == self._shared_critic_ops_name: - ops_params = dict(self._ops_params) - ops_params.update( - { - "policy_idx": -1, - "shared_critic": False, - }, + return DiscreteMADDPGOps( + name=name, + policy=self._placeholder_policy, + param=self._params, + shared_critic=False, + policy_idx=-1, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, ) - return DiscreteMADDPGOps(name=name, **ops_params) else: - ops_params = dict(self._ops_params) - ops_params.update( - { - "policy_creator": self._policy_creator[name], - "policy_idx": self._policy_names.index(name), - }, + return DiscreteMADDPGOps( + name=name, + policy=self._policy_dict[name], + param=self._params, + shared_critic=self._params.shared_critic, + policy_idx=self._policy_names.index(name), + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, ) - return DiscreteMADDPGOps(name=name, **ops_params) def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) @@ -405,6 +416,7 @@ def train_step(self) -> None: # Update critic if self._params.shared_critic: + assert self._critic_ops is not None self._critic_ops.update_critic(batch, next_actions) critic_state_dict = self._critic_ops.get_critic_state() # Sync latest critic to ops @@ -431,6 +443,7 @@ async def train_step_as_task(self) -> None: # Update critic if self._params.shared_critic: + assert self._critic_ops is not None critic_grad = await asyncio.gather(*[self._critic_ops.get_critic_grad(batch, next_actions)]) assert isinstance(critic_grad, list) and isinstance(critic_grad[0], dict) self._critic_ops.update_critic_with_grad(critic_grad[0]) @@ -460,10 +473,11 @@ def _try_soft_update_target(self) -> None: for ops in self._actor_ops_list: ops.soft_update_target() if self._params.shared_critic: + assert self._critic_ops is not None self._critic_ops.soft_update_target() self._target_policy_version = self._policy_version - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: self._assert_ops_exists() ret_policy_state = {} for ops in self._actor_ops_list: @@ -484,6 +498,7 @@ def save(self, path: str) -> None: trainer_state = {ops.name: ops.get_state() for ops in self._actor_ops_list} if self._params.shared_critic: + assert self._critic_ops is not None trainer_state[self._critic_ops.name] = self._critic_ops.get_state() policy_state_dict = {ops_name: state["policy"] for ops_name, state in trainer_state.items()} diff --git a/maro/rl/training/algorithms/ppo.py b/maro/rl/training/algorithms/ppo.py index 417c12576..7abe089ef 100644 --- a/maro/rl/training/algorithms/ppo.py +++ b/maro/rl/training/algorithms/ppo.py @@ -2,16 +2,16 @@ # Licensed under the MIT license. 
from dataclasses import dataclass -from typing import Callable, Dict, Tuple +from typing import Tuple import numpy as np import torch from torch.distributions import Categorical -from maro.rl.model import VNet from maro.rl.policy import DiscretePolicyGradient, RLPolicy from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer from maro.rl.utils import TransitionBatch, discount_cumsum, ndarray_to_tensor +from maro.utils import clone @dataclass @@ -23,21 +23,7 @@ class PPOParams(ACBasedParams): If it is None, the actor loss is calculated using the usual policy gradient theorem. """ - clip_ratio: float = None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_v_critic_net_func": self.get_v_critic_net_func, - "reward_discount": self.reward_discount, - "critic_loss_cls": self.critic_loss_cls, - "clip_ratio": self.clip_ratio, - "lam": self.lam, - "min_logp": self.min_logp, - "is_discrete_action": self.is_discrete_action, - } - def __post_init__(self) -> None: - assert self.get_v_critic_net_func is not None assert self.clip_ratio is not None @@ -45,31 +31,20 @@ class DiscretePPOWithEntropyOps(ACBasedOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - get_v_critic_net_func: Callable[[], VNet], + policy: RLPolicy, + params: ACBasedParams, parallelism: int = 1, reward_discount: float = 0.9, - critic_loss_cls: Callable = None, - clip_ratio: float = None, - lam: float = 0.9, - min_logp: float = None, - is_discrete_action: bool = True, ) -> None: super(DiscretePPOWithEntropyOps, self).__init__( - name=name, - policy_creator=policy_creator, - get_v_critic_net_func=get_v_critic_net_func, - parallelism=parallelism, - reward_discount=reward_discount, - critic_loss_cls=critic_loss_cls, - clip_ratio=clip_ratio, - lam=lam, - min_logp=min_logp, - is_discrete_action=is_discrete_action, + name, + policy, + params, + reward_discount, + parallelism, ) - assert is_discrete_action - assert isinstance(self._policy, DiscretePolicyGradient) - self._policy_old = self._policy_creator() + assert self._is_discrete_action + self._policy_old: DiscretePolicyGradient = clone(policy) self.update_policy_old() def update_policy_old(self) -> None: @@ -172,8 +147,23 @@ class PPOTrainer(ACBasedTrainer): https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo. 
""" - def __init__(self, name: str, params: PPOParams) -> None: - super(PPOTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: PPOParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(PPOTrainer, self).__init__( + name, + params, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) class DiscretePPOWithEntropyTrainer(ACBasedTrainer): @@ -182,10 +172,11 @@ def __init__(self, name: str, params: PPOParams) -> None: def get_local_ops(self) -> DiscretePPOWithEntropyOps: return DiscretePPOWithEntropyOps( - name=self._policy_name, - policy_creator=self._policy_creator, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def train_step(self) -> None: diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index d5bcdfc24..338addf57 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -2,73 +2,59 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, cast import torch from maro.rl.model import QNet from maro.rl.policy import ContinuousRLPolicy, RLPolicy -from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @dataclass -class SoftActorCriticParams(TrainerParams): - get_q_critic_net_func: Callable[[], QNet] = None +class SoftActorCriticParams(BaseTrainerParams): + get_q_critic_net_func: Callable[[], QNet] update_target_every: int = 5 random_overwrite: bool = False entropy_coef: float = 0.1 num_epochs: int = 1 n_start_train: int = 0 - q_value_loss_cls: Callable = None + q_value_loss_cls: Optional[Callable] = None soft_update_coef: float = 1.0 - def __post_init__(self) -> None: - assert self.get_q_critic_net_func is not None - - def extract_ops_params(self) -> Dict[str, object]: - return { - "get_q_critic_net_func": self.get_q_critic_net_func, - "entropy_coef": self.entropy_coef, - "reward_discount": self.reward_discount, - "q_value_loss_cls": self.q_value_loss_cls, - "soft_update_coef": self.soft_update_coef, - } - class SoftActorCriticOps(AbsTrainOps): def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], - get_q_critic_net_func: Callable[[], QNet], + policy: RLPolicy, + params: SoftActorCriticParams, + reward_discount: float = 0.9, parallelism: int = 1, - *, - entropy_coef: float, - reward_discount: float, - q_value_loss_cls: Callable = None, - soft_update_coef: float = 1.0, ) -> None: super(SoftActorCriticOps, self).__init__( name=name, - policy_creator=policy_creator, + policy=policy, parallelism=parallelism, ) assert isinstance(self._policy, ContinuousRLPolicy) - self._q_net1 = get_q_critic_net_func() - self._q_net2 = get_q_critic_net_func() + self._q_net1 = params.get_q_critic_net_func() + self._q_net2 = params.get_q_critic_net_func() self._target_q_net1: QNet = clone(self._q_net1) self._target_q_net1.eval() self._target_q_net2: QNet = clone(self._q_net2) 
self._target_q_net2.eval() - self._entropy_coef = entropy_coef - self._soft_update_coef = soft_update_coef + self._entropy_coef = params.entropy_coef + self._soft_update_coef = params.soft_update_coef self._reward_discount = reward_discount - self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() + self._q_value_loss_func = ( + params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss() + ) def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]: self._q_net1.train() @@ -100,11 +86,11 @@ def get_critic_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tenso grad_q2 = self._q_net2.get_gradients(loss_q2) return grad_q1, grad_q2 - def update_critic_with_grad(self, grad_dict1: dict, grad_dict2: dict) -> None: + def update_critic_with_grad(self, grad_dicts: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: self._q_net1.train() self._q_net2.train() - self._q_net1.apply_gradients(grad_dict1) - self._q_net2.apply_gradients(grad_dict2) + self._q_net1.apply_gradients(grad_dicts[0]) + self._q_net2.apply_gradients(grad_dicts[1]) def update_critic(self, batch: TransitionBatch) -> None: self._q_net1.train() @@ -154,7 +140,7 @@ def soft_update_target(self) -> None: self._target_q_net1.soft_update(self._q_net1, self._soft_update_coef) self._target_q_net2.soft_update(self._q_net2, self._soft_update_coef) - def to_device(self, device: str) -> None: + def to_device(self, device: str = None) -> None: self._device = get_torch_device(device=device) self._q_net1.to(self._device) self._q_net2.to(self._device) @@ -163,22 +149,38 @@ def to_device(self, device: str) -> None: class SoftActorCriticTrainer(SingleAgentTrainer): - def __init__(self, name: str, params: SoftActorCriticParams) -> None: - super(SoftActorCriticTrainer, self).__init__(name, params) + def __init__( + self, + name: str, + params: SoftActorCriticParams, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(SoftActorCriticTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) self._params = params self._qnet_version = self._target_qnet_version = 0 - self._replay_memory: Optional[RandomReplayMemory] = None - def build(self) -> None: - self._ops = self.get_ops() + self._ops = cast(SoftActorCriticOps, self.get_ops()) self._replay_memory = RandomReplayMemory( - capacity=self._params.replay_memory_capacity, + capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, random_overwrite=self._params.random_overwrite, ) + def _register_policy(self, policy: RLPolicy) -> None: + assert isinstance(policy, ContinuousRLPolicy) + self._policy = policy + def train_step(self) -> None: assert isinstance(self._ops, SoftActorCriticOps) @@ -218,10 +220,11 @@ def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatc def get_local_ops(self) -> SoftActorCriticOps: return SoftActorCriticOps( - name=self._policy_name, - policy_creator=self._policy_creator, - parallelism=self._params.data_parallelism, - **self._params.extract_ops_params(), + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, ) def _get_batch(self, batch_size: int = None) -> TransitionBatch: diff --git a/maro/rl/training/proxy.py 
b/maro/rl/training/proxy.py index 29eaaed7a..04f1af849 100644 --- a/maro/rl/training/proxy.py +++ b/maro/rl/training/proxy.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. from collections import defaultdict, deque +from typing import Deque -from maro.rl.distributed import AbsProxy +from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT, AbsProxy from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes from maro.rl.utils.torch_utils import average_grads from maro.utils import LoggerV2 @@ -20,13 +21,16 @@ class TrainingProxy(AbsProxy): backend_port (int, default=10001): Network port for communicating with back-end workers (task consumers). """ - def __init__(self, frontend_port: int = 10000, backend_port: int = 10001) -> None: - super(TrainingProxy, self).__init__(frontend_port=frontend_port, backend_port=backend_port) - self._available_workers = deque() - self._worker_ready = False - self._connected_ops = set() - self._result_cache = defaultdict(list) - self._expected_num_results = {} + def __init__(self, frontend_port: int = None, backend_port: int = None) -> None: + super(TrainingProxy, self).__init__( + frontend_port=frontend_port if frontend_port is not None else DEFAULT_TRAINING_FRONTEND_PORT, + backend_port=backend_port if backend_port is not None else DEFAULT_TRAINING_BACKEND_PORT, + ) + self._available_workers: Deque = deque() + self._worker_ready: bool = False + self._connected_ops: set = set() + self._result_cache: dict = defaultdict(list) + self._expected_num_results: dict = {} self._logger = LoggerV2("TRAIN-PROXY") def _route_request_to_compute_node(self, msg: list) -> None: @@ -48,10 +52,12 @@ def _route_request_to_compute_node(self, msg: list) -> None: self._connected_ops.add(msg[0]) req = bytes_to_pyobj(msg[-1]) + assert isinstance(req, dict) + desired_parallelism = req["desired_parallelism"] req["args"] = list(req["args"]) batch = req["args"][0] - workers = [] + workers: list = [] while len(workers) < desired_parallelism and self._available_workers: workers.append(self._available_workers.popleft()) diff --git a/maro/rl/training/train_ops.py b/maro/rl/training/train_ops.py index 869364e81..57888038a 100644 --- a/maro/rl/training/train_ops.py +++ b/maro/rl/training/train_ops.py @@ -3,8 +3,9 @@ import inspect from abc import ABCMeta, abstractmethod -from typing import Callable, Tuple +from typing import Any, Callable, Optional, Tuple, Union +import torch import zmq from zmq.asyncio import Context, Poller @@ -19,24 +20,21 @@ class AbsTrainOps(object, metaclass=ABCMeta): Args: name (str): Name of the ops. This is usually a policy name. - policy_creator (Callable[[], RLPolicy]): Function to create a policy instance. + policy (RLPolicy): Policy instance. parallelism (int, default=1): Desired degree of data parallelism. """ def __init__( self, name: str, - policy_creator: Callable[[], RLPolicy], + policy: RLPolicy, parallelism: int = 1, ) -> None: super(AbsTrainOps, self).__init__() self._name = name - self._policy_creator = policy_creator - # Create the policy. 
- if self._policy_creator: - self._policy = self._policy_creator() - + self._policy = policy self._parallelism = parallelism + self._device: Optional[torch.device] = None @property def name(self) -> str: @@ -44,11 +42,11 @@ def name(self) -> str: @property def policy_state_dim(self) -> int: - return self._policy.state_dim if self._policy_creator else None + return self._policy.state_dim @property def policy_action_dim(self) -> int: - return self._policy.action_dim if self._policy_creator else None + return self._policy.action_dim @property def parallelism(self) -> int: @@ -75,20 +73,20 @@ def set_state(self, ops_state_dict: dict) -> None: self.set_policy_state(ops_state_dict["policy"][1]) self.set_non_policy_state(ops_state_dict["non_policy"]) - def get_policy_state(self) -> Tuple[str, object]: + def get_policy_state(self) -> Tuple[str, dict]: """Get the policy's state. Returns: policy_name (str) - policy_state (object) + policy_state (Any) """ return self._policy.name, self._policy.get_state() - def set_policy_state(self, policy_state: object) -> None: + def set_policy_state(self, policy_state: dict) -> None: """Update the policy's state. Args: - policy_state (object): The policy state. + policy_state (dict): The policy state. """ self._policy.set_state(policy_state) @@ -111,17 +109,17 @@ def set_non_policy_state(self, state: dict) -> None: raise NotImplementedError @abstractmethod - def to_device(self, device: str): + def to_device(self, device: str = None) -> None: raise NotImplementedError -def remote(func) -> Callable: +def remote(func: Callable) -> Callable: """Annotation to indicate that a function / method can be called remotely. This annotation takes effect only when an ``AbsTrainOps`` object is wrapped by a ``RemoteOps``. """ - def remote_annotate(*args, **kwargs) -> object: + def remote_annotate(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) return remote_annotate @@ -137,7 +135,7 @@ class AsyncClient(object): """ def __init__(self, name: str, address: Tuple[str, int], logger: LoggerV2 = None) -> None: - self._logger = DummyLogger() if logger is None else logger + self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger() self._name = name host, port = address self._proxy_ip = get_ip_address_by_hostname(host) @@ -155,7 +153,7 @@ async def send_request(self, req: dict) -> None: await self._socket.send(pyobj_to_bytes(req)) self._logger.debug(f"{self._name} sent request {req['func']}") - async def get_response(self) -> object: + async def get_response(self) -> Any: """Waits for a result in asynchronous fashion. This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``get_response`` calls. 
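Putting the pieces above together: a method marked with @remote runs as a plain local call until its AbsTrainOps owner is wrapped in a RemoteOps, after which the decorated call is shipped to the training proxy and awaited. A minimal sketch with assumed names (MyOps, my_policy, "proxy-host"):

from maro.rl.distributed import DEFAULT_TRAINING_FRONTEND_PORT
from maro.rl.training import AbsTrainOps, RemoteOps, remote


class MyOps(AbsTrainOps):
    @remote
    def get_batch_grad(self, batch):
        ...  # expensive gradient computation that may be farmed out to remote workers

    def get_non_policy_state(self) -> dict:
        return {}

    def set_non_policy_state(self, state: dict) -> None:
        pass

    def to_device(self, device: str = None) -> None:
        pass


ops = MyOps(name="my_policy", policy=my_policy)  # my_policy: an RLPolicy instance, built elsewhere
remote_ops = RemoteOps(ops, ("proxy-host", DEFAULT_TRAINING_FRONTEND_PORT))
# On remote_ops, @remote-decorated methods become coroutines routed through the TrainingProxy.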
@@ -209,15 +207,15 @@ def __init__(self, ops: AbsTrainOps, address: Tuple[str, int], logger: LoggerV2 self._client = AsyncClient(self._ops.name, address, logger=logger) self._client.connect() - def __getattribute__(self, attr_name: str) -> object: + def __getattribute__(self, attr_name: str) -> Any: # Ignore methods that belong to the parent class try: return super().__getattribute__(attr_name) except AttributeError: pass - def remote_method(ops_state, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable: - async def remote_call(*args, **kwargs) -> object: + def remote_method(ops_state: Any, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable: + async def remote_call(*args: Any, **kwargs: Any) -> Any: req = { "state": ops_state, "func": func_name, diff --git a/maro/rl/training/trainer.py b/maro/rl/training/trainer.py index a34517ac5..8bced5674 100644 --- a/maro/rl/training/trainer.py +++ b/maro/rl/training/trainer.py @@ -5,7 +5,7 @@ import os from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -21,37 +21,8 @@ @dataclass -class TrainerParams: - """Common trainer parameters. - - replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory. - batch_size (int, default=128): Training batch size. - data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when - a model is large and computing gradients with respect to a batch becomes expensive. In this case, the - batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set - of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets - updated only after collecting all the gradients from the remote nodes. Note that this value is the desired - parallelism and the actual parallelism in a distributed experiment may be smaller depending on the - availability of compute resources. For details on distributed deep learning and data parallelism, see - https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance - of resources available on the internet. - reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology. - - """ - - replay_memory_capacity: int = 10000 - batch_size: int = 128 - data_parallelism: int = 1 - reward_discount: float = 0.9 - - @abstractmethod - def extract_ops_params(self) -> Dict[str, object]: - """Extract parameters that should be passed to the train ops. - - Returns: - params (Dict[str, object]): Parameter dict. - """ - raise NotImplementedError +class BaseTrainerParams: + pass class AbsTrainer(object, metaclass=ABCMeta): @@ -64,16 +35,36 @@ class AbsTrainer(object, metaclass=ABCMeta): Args: name (str): Name of the trainer. - params (TrainerParams): Trainer's parameters. + replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory. + batch_size (int, default=128): Training batch size. + data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when + a model is large and computing gradients with respect to a batch becomes expensive. In this case, the + batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set + of remote nodes. 
For simplicity, only synchronous parallelism is supported, meaning that the model gets + updated only after collecting all the gradients from the remote nodes. Note that this value is the desired + parallelism and the actual parallelism in a distributed experiment may be smaller depending on the + availability of compute resources. For details on distributed deep learning and data parallelism, see + https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an + abundance of resources available on the internet. + reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology. """ - def __init__(self, name: str, params: TrainerParams) -> None: + def __init__( + self, + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: self._name = name - self._params = params - self._batch_size = self._params.batch_size + self._replay_memory_capacity = replay_memory_capacity + self._batch_size = batch_size + self._data_parallelism = data_parallelism + self._reward_discount = reward_discount + self._agent2policy: Dict[Any, str] = {} self._proxy_address: Optional[Tuple[str, int]] = None - self._logger = None @property def name(self) -> str: @@ -83,13 +74,11 @@ def name(self) -> str: def agent_num(self) -> int: return len(self._agent2policy) - def register_logger(self, logger: LoggerV2) -> None: + def register_logger(self, logger: LoggerV2 = None) -> None: self._logger = logger def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_mapping: Dict[str, str]) -> None: - """Register the agent to policy dict that correspond to the current trainer. A valid policy name should start - with the name of its trainer. For example, "DQN.POLICY_NAME". Therefore, we could identify which policies - should be registered to the current trainer according to the policy's name. + """Register the agent to policy dict that correspond to the current trainer. Args: agent2policy (Dict[Any, str]): Agent name to policy name mapping. @@ -102,16 +91,11 @@ def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_map } @abstractmethod - def register_policy_creator( - self, - global_policy_creator: Dict[str, Callable[[], AbsPolicy]], - policy_trainer_mapping: Dict[str, str], - ) -> None: - """Register the policy creator. Only keep the creators of the policies that the current trainer need to train. + def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None: + """Register the policies. Only keep the creators of the policies that the current trainer need to train. Args: - global_policy_creator (Dict[str, Callable[[], AbsPolicy]]): Dict that contains the creators for all - policies. + policies (List[AbsPolicy]): All policies. policy_trainer_mapping (Dict[str, str]): Policy name to trainer name mapping. """ raise NotImplementedError @@ -147,7 +131,7 @@ def set_proxy_address(self, proxy_address: Tuple[str, int]) -> None: self._proxy_address = proxy_address @abstractmethod - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: """Get policies' states. 
Returns: @@ -171,30 +155,46 @@ async def exit(self) -> None: class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta): """Policy trainer that trains only one policy.""" - def __init__(self, name: str, params: TrainerParams) -> None: - super(SingleAgentTrainer, self).__init__(name, params) - self._policy_name: Optional[str] = None - self._policy_creator: Optional[Callable[[], RLPolicy]] = None - self._ops: Optional[AbsTrainOps] = None - self._replay_memory: Optional[ReplayMemory] = None + def __init__( + self, + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, + ) -> None: + super(SingleAgentTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) @property - def ops(self): - return self._ops + def ops(self) -> Union[AbsTrainOps, RemoteOps]: + ops = getattr(self, "_ops", None) + assert isinstance(ops, (AbsTrainOps, RemoteOps)) + return ops - def register_policy_creator( - self, - global_policy_creator: Dict[str, Callable[[], AbsPolicy]], - policy_trainer_mapping: Dict[str, str], - ) -> None: - policy_names = [ - policy_name for policy_name in global_policy_creator if policy_trainer_mapping[policy_name] == self.name - ] - if len(policy_names) != 1: + @property + def replay_memory(self) -> ReplayMemory: + replay_memory = getattr(self, "_replay_memory", None) + assert isinstance(replay_memory, ReplayMemory), "Replay memory is required." + return replay_memory + + def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None: + policies = [policy for policy in policies if policy_trainer_mapping[policy.name] == self.name] + if len(policies) != 1: raise ValueError(f"Trainer {self._name} should have exactly one policy assigned to it") - self._policy_name = policy_names.pop() - self._policy_creator = global_policy_creator[self._policy_name] + policy = policies.pop() + assert isinstance(policy, RLPolicy) + self._register_policy(policy) + + @abstractmethod + def _register_policy(self, policy: RLPolicy) -> None: + raise NotImplementedError @abstractmethod def get_local_ops(self) -> AbsTrainOps: @@ -216,9 +216,9 @@ def get_ops(self) -> Union[RemoteOps, AbsTrainOps]: ops = self.get_local_ops() return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: self._assert_ops_exists() - policy_name, state = self._ops.get_policy_state() + policy_name, state = self.ops.get_policy_state() return {policy_name: state} def load(self, path: str) -> None: @@ -227,7 +227,7 @@ def load(self, path: str) -> None: policy_state = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) non_policy_state = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) - self._ops.set_state( + self.ops.set_state( { "policy": policy_state, "non_policy": non_policy_state, @@ -237,7 +237,7 @@ def load(self, path: str) -> None: def save(self, path: str) -> None: self._assert_ops_exists() - ops_state = self._ops.get_state() + ops_state = self.ops.get_state() policy_state = ops_state["policy"] non_policy_state = ops_state["non_policy"] @@ -267,46 +267,57 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: next_states=np.vstack([exp[4] for exp in exps]), ) transition_batch = self._preprocess_batch(transition_batch) - self._replay_memory.put(transition_batch) + 
self.replay_memory.put(transition_batch) @abstractmethod def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: raise NotImplementedError def _assert_ops_exists(self) -> None: - if not self._ops: + if not self.ops: raise ValueError("'build' needs to be called to create an ops instance first.") async def exit(self) -> None: self._assert_ops_exists() - if isinstance(self._ops, RemoteOps): - await self._ops.exit() + ops = self.ops + if isinstance(ops, RemoteOps): + await ops.exit() class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta): """Policy trainer that trains multiple policies.""" - def __init__(self, name: str, params: TrainerParams) -> None: - super(MultiAgentTrainer, self).__init__(name, params) - self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {} - self._policy_names: List[str] = [] - self._ops_dict: Dict[str, AbsTrainOps] = {} - - @property - def ops_dict(self): - return self._ops_dict - - def register_policy_creator( + def __init__( self, - global_policy_creator: Dict[str, Callable[[], AbsPolicy]], - policy_trainer_mapping: Dict[str, str], + name: str, + replay_memory_capacity: int = 10000, + batch_size: int = 128, + data_parallelism: int = 1, + reward_discount: float = 0.9, ) -> None: - self._policy_creator: Dict[str, Callable[[], RLPolicy]] = { - policy_name: func - for policy_name, func in global_policy_creator.items() - if policy_trainer_mapping[policy_name] == self.name - } - self._policy_names = list(self._policy_creator.keys()) + super(MultiAgentTrainer, self).__init__( + name, + replay_memory_capacity, + batch_size, + data_parallelism, + reward_discount, + ) + + @property + def ops_dict(self) -> Dict[str, AbsTrainOps]: + ops_dict = getattr(self, "_ops_dict", None) + assert isinstance(ops_dict, dict) + return ops_dict + + def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None: + self._policy_names: List[str] = [ + policy.name for policy in policies if policy_trainer_mapping[policy.name] == self.name + ] + self._policy_dict: Dict[str, RLPolicy] = {} + for policy in policies: + if policy_trainer_mapping[policy.name] == self.name: + assert isinstance(policy, RLPolicy) + self._policy_dict[policy.name] = policy @abstractmethod def get_local_ops(self, name: str) -> AbsTrainOps: @@ -335,7 +346,7 @@ def get_ops(self, name: str) -> Union[RemoteOps, AbsTrainOps]: return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops @abstractmethod - def get_policy_state(self) -> Dict[str, object]: + def get_policy_state(self) -> Dict[str, dict]: raise NotImplementedError @abstractmethod diff --git a/maro/rl/training/training_manager.py b/maro/rl/training/training_manager.py index 9136fdeed..9d6b36b15 100644 --- a/maro/rl/training/training_manager.py +++ b/maro/rl/training/training_manager.py @@ -7,7 +7,6 @@ import collections import os import typing -from itertools import chain from typing import Any, Dict, Iterable, List, Tuple from maro.rl.rollout import ExpElement @@ -26,8 +25,8 @@ class TrainingManager(object): Training manager. Manage and schedule all trainers to train policies. Args: - rl_component_bundle (RLComponentBundle): The RL component bundle of the job. - explicit_assign_device (bool): Whether to assign policy to its device in the training manager. + rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow. + explicit_assign_device (bool, default=False): Whether to assign policy to its device in the training manager. 
proxy_address (Tuple[str, int], default=None): Address of the training proxy. If it is not None, it is registered to all trainers, which in turn create `RemoteOps` for distributed training. logger (LoggerV2, default=None): A logger for logging key events. @@ -36,36 +35,33 @@ class TrainingManager(object): def __init__( self, rl_component_bundle: RLComponentBundle, - explicit_assign_device: bool, + explicit_assign_device: bool = False, proxy_address: Tuple[str, int] = None, logger: LoggerV2 = None, ) -> None: super(TrainingManager, self).__init__() - self._trainer_dict: Dict[str, AbsTrainer] = {} self._proxy_address = proxy_address - for trainer_name, func in rl_component_bundle.trainer_creator.items(): - trainer = func() + + self._trainer_dict: Dict[str, AbsTrainer] = {} + for trainer in rl_component_bundle.trainers: if self._proxy_address: trainer.set_proxy_address(self._proxy_address) trainer.register_agent2policy( - rl_component_bundle.trainable_agent2policy, - rl_component_bundle.policy_trainer_mapping, + agent2policy=rl_component_bundle.trainable_agent2policy, + policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping, ) - trainer.register_policy_creator( - rl_component_bundle.trainable_policy_creator, - rl_component_bundle.policy_trainer_mapping, + trainer.register_policies( + policies=rl_component_bundle.policies, + policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping, ) trainer.register_logger(logger) - trainer.build() # `build()` must be called after `register_policy_creator()` - self._trainer_dict[trainer_name] = trainer + trainer.build() # `build()` must be called after `register_policies()` + self._trainer_dict[trainer.name] = trainer # User-defined allocation of compute devices, i.e., GPU's to the trainer ops if explicit_assign_device: for policy_name, device_name in rl_component_bundle.device_mapping.items(): - if policy_name not in rl_component_bundle.policy_trainer_mapping: # No need to assign device - continue - trainer = self._trainer_dict[rl_component_bundle.policy_trainer_mapping[policy_name]] if isinstance(trainer, SingleAgentTrainer): @@ -95,13 +91,16 @@ async def train_step() -> Iterable: for trainer in self._trainer_dict.values(): trainer.train_step() - def get_policy_state(self) -> Dict[str, Dict[str, object]]: + def get_policy_state(self) -> Dict[str, dict]: """Get policies' states. Returns: A double-deck dict with format: {trainer_name: {policy_name: policy_state}} """ - return dict(chain(*[trainer.get_policy_state().items() for trainer in self._trainer_dict.values()])) + policy_states: Dict[str, dict] = {} + for trainer in self._trainer_dict.values(): + policy_states.update(trainer.get_policy_state()) + return policy_states def record_experiences(self, experiences: List[List[ExpElement]]) -> None: """Record experiences collected from external modules (for example, EnvSampler). diff --git a/maro/rl/training/worker.py b/maro/rl/training/worker.py index 000f5973f..4cb1528f4 100644 --- a/maro/rl/training/worker.py +++ b/maro/rl/training/worker.py @@ -6,7 +6,7 @@ import typing from typing import Dict -from maro.rl.distributed import AbsWorker +from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, AbsWorker from maro.rl.training import SingleAgentTrainer from maro.rl.utils.common import bytes_to_pyobj, bytes_to_string, pyobj_to_bytes from maro.utils import LoggerV2 @@ -24,7 +24,7 @@ class TrainOpsWorker(AbsWorker): Args: idx (int): Integer identifier for the worker. 
It is used to generate an internal ID, "worker.{idx}", so that the proxy can keep track of its connection status. - rl_component_bundle (RLComponentBundle): The RL component bundle of the job. + rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow. producer_host (str): IP address of the proxy host to connect to. producer_port (int, default=10001): Port of the proxy host to connect to. """ @@ -34,13 +34,13 @@ def __init__( idx: int, rl_component_bundle: RLComponentBundle, producer_host: str, - producer_port: int = 10001, + producer_port: int = None, logger: LoggerV2 = None, ) -> None: super(TrainOpsWorker, self).__init__( idx=idx, producer_host=producer_host, - producer_port=producer_port, + producer_port=producer_port if producer_port is not None else DEFAULT_TRAINING_BACKEND_PORT, logger=logger, ) @@ -62,13 +62,17 @@ def _compute(self, msg: list) -> None: ops_name, req = bytes_to_string(msg[0]), bytes_to_pyobj(msg[-1]) assert isinstance(req, dict) + trainer_dict: Dict[str, AbsTrainer] = { + trainer.name: trainer for trainer in self._rl_component_bundle.trainers + } + if ops_name not in self._ops_dict: - trainer_name = ops_name.split(".")[0] + trainer_name = self._rl_component_bundle.policy_trainer_mapping[ops_name] if trainer_name not in self._trainer_dict: - trainer = self._rl_component_bundle.trainer_creator[trainer_name]() - trainer.register_policy_creator( - self._rl_component_bundle.trainable_policy_creator, - self._rl_component_bundle.policy_trainer_mapping, + trainer = trainer_dict[trainer_name] + trainer.register_policies( + policies=self._rl_component_bundle.policies, + policy_trainer_mapping=self._rl_component_bundle.policy_trainer_mapping, ) self._trainer_dict[trainer_name] = trainer diff --git a/maro/rl/utils/common.py b/maro/rl/utils/common.py index e69b907b7..516239670 100644 --- a/maro/rl/utils/common.py +++ b/maro/rl/utils/common.py @@ -4,17 +4,17 @@ import os import pickle import socket -from typing import List, Optional +from typing import Any, List, Optional -def get_env(var_name: str, required: bool = True, default: object = None) -> str: +def get_env(var_name: str, required: bool = True, default: str = None) -> Optional[str]: """Wrapper for os.getenv() that includes a check for mandatory environment variables. Args: var_name (str): Variable name. required (bool, default=True): Flag indicating whether the environment variable in questions is required. If this is true and the environment variable is not present in ``os.environ``, a ``KeyError`` is raised. - default (object, default=None): Default value for the environment variable if it is missing in ``os.environ`` + default (str, default=None): Default value for the environment variable if it is missing in ``os.environ`` and ``required`` is false. Ignored if ``required`` is True. 
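# Illustrative sketch only (not part of the patch): with the TrainOpsWorker change above,
# the owning trainer of an ops/policy is resolved through policy_trainer_mapping instead of
# being parsed from the old "TRAINER.POLICY" name-prefix convention. Names are hypothetical.
policy_trainer_mapping = {"ppo_0.policy": "ppo_0", "dqn.policy": "dqn"}
trainer_name = policy_trainer_mapping["dqn.policy"]  # -> "dqn"; no string splitting needed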
Returns: @@ -52,11 +52,11 @@ def bytes_to_string(bytes_: bytes) -> str: return bytes_.decode(DEFAULT_MSG_ENCODING) -def pyobj_to_bytes(pyobj) -> bytes: +def pyobj_to_bytes(pyobj: Any) -> bytes: return pickle.dumps(pyobj) -def bytes_to_pyobj(bytes_: bytes) -> object: +def bytes_to_pyobj(bytes_: bytes) -> Any: return pickle.loads(bytes_) diff --git a/maro/rl/utils/torch_utils.py b/maro/rl/utils/torch_utils.py index 3335fe24a..82476411f 100644 --- a/maro/rl/utils/torch_utils.py +++ b/maro/rl/utils/torch_utils.py @@ -55,5 +55,5 @@ def average_grads(grad_list: List[dict]) -> dict: } -def get_torch_device(device: str = None): +def get_torch_device(device: str = None) -> torch.device: return torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu")) diff --git a/maro/rl/workflows/config/parser.py b/maro/rl/workflows/config/parser.py index 68e8cbb30..db52f065a 100644 --- a/maro/rl/workflows/config/parser.py +++ b/maro/rl/workflows/config/parser.py @@ -207,7 +207,7 @@ def _validate_checkpointing_section(self, section: dict) -> None: f"{self._validation_err_pfx}: 'training.checkpointing.interval' must be an int", ) - def _validate_logging_section(self, component, level_dict: dict) -> None: + def _validate_logging_section(self, component: str, level_dict: dict) -> None: if any(key not in {"stdout", "file"} for key in level_dict): raise KeyError( f"{self._validation_err_pfx}: fields under section '{component}.logging' must be 'stdout' or 'file'", ) @@ -261,7 +261,7 @@ def get_job_spec(self, containerize: bool = False) -> Dict[str, Tuple[str, Dict[ num_episodes = self._config["main"]["num_episodes"] main_proc = f"{self._config['job']}.main" min_n_sample = self._config["main"].get("min_n_sample", 1) - env = { + env: dict = { main_proc: ( os.path.join(self._get_workflow_path(containerize=containerize), "main.py"), { diff --git a/maro/rl/workflows/main.py b/maro/rl/workflows/main.py index a46dd0a55..31de7caa1 100644 --- a/maro/rl/workflows/main.py +++ b/maro/rl/workflows/main.py @@ -6,116 +6,157 @@ import os import sys import time -from typing import List, Type +from typing import List, Union from maro.rl.rl_component.rl_component_bundle import RLComponentBundle -from maro.rl.rollout import BatchEnvSampler, ExpElement +from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler, ExpElement from maro.rl.training import TrainingManager from maro.rl.utils import get_torch_device from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none from maro.rl.utils.training import get_latest_ep +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="MARO RL workflow parser") - parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow") - return parser.parse_args() +class WorkflowEnvAttributes: + def __init__(self) -> None: + # Number of training episodes + self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES"))) + # Maximum number of steps in one round of sampling. + self.num_steps = int_or_none(get_env("NUM_STEPS", required=False)) -def main(rl_component_bundle: RLComponentBundle, args: argparse.Namespace) -> None: - if args.evaluate_only: - evaluate_only_workflow(rl_component_bundle) - else: - training_workflow(rl_component_bundle) + # Minimum number of data samples to start a round of training.
If the data samples are insufficient, re-run + # data sampling until we have at least `min_n_sample` data entries. + self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE"))) + # Path to store logs. + self.log_path = get_env("LOG_PATH") -def training_workflow(rl_component_bundle: RLComponentBundle) -> None: - num_episodes = int(get_env("NUM_EPISODES")) - num_steps = int_or_none(get_env("NUM_STEPS", required=False)) - min_n_sample = int_or_none(get_env("MIN_N_SAMPLE")) + # Log levels + self.log_level_stdout = get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL") + self.log_level_file = get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL") - logger = LoggerV2( - "MAIN", - dump_path=get_env("LOG_PATH"), - dump_mode="a", - stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), - file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), - ) - logger.info("Start training workflow.") - - env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False)) - env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False)) - parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None - train_mode = get_env("TRAIN_MODE") - - is_single_thread = train_mode == "simple" and not parallel_rollout - if is_single_thread: - rl_component_bundle.pre_create_policy_instances() - - if parallel_rollout: - env_sampler = BatchEnvSampler( - sampling_parallelism=env_sampling_parallelism, - port=int(get_env("ROLLOUT_CONTROLLER_PORT")), - min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)), - grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)), - eval_parallelism=env_eval_parallelism, - logger=logger, + # Parallelism of sampling / evaluation. Used in distributed sampling. + self.env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False)) + self.env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False)) + + # Training mode, simple or distributed + self.train_mode = get_env("TRAIN_MODE") + + # Evaluating schedule. + self.eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False)) + + # Restore configurations. + self.load_path = get_env("LOAD_PATH", required=False) + self.load_episode = int_or_none(get_env("LOAD_EPISODE", required=False)) + + # Checkpointing configurations. + self.checkpoint_path = get_env("CHECKPOINT_PATH", required=False) + self.checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False)) + + # Parallel sampling configurations. + self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None + if self.parallel_rollout: + self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT"))) + self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False)) + self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False)) + + self.is_single_thread = self.train_mode == "simple" and not self.parallel_rollout + + # Distributed training configurations. 
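# Illustrative sketch only (not part of the patch): WorkflowEnvAttributes reads plain strings
# from os.environ, so a local run could be configured before launching main.py roughly as below.
# All values are hypothetical; any TRAIN_MODE other than "simple" additionally requires the
# TRAIN_PROXY_* variables handled by the check that follows.
import os

os.environ.update(
    {
        "NUM_EPISODES": "30",
        "MIN_N_SAMPLE": "1",
        "LOG_PATH": "/tmp/maro_logs",
        "TRAIN_MODE": "simple",
    }
)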
+ if self.train_mode != "simple": + self.proxy_address = ( + env_str_helper(get_env("TRAIN_PROXY_HOST")), + int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))), + ) + + self.logger = LoggerV2( + "MAIN", + dump_path=self.log_path, + dump_mode="a", + stdout_level=self.log_level_stdout, + file_level=self.log_level_file, + ) + + +def _get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="MARO RL workflow parser") + parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow") + return parser.parse_args() + + +def _get_env_sampler( + rl_component_bundle: RLComponentBundle, + env_attr: WorkflowEnvAttributes, +) -> Union[AbsEnvSampler, BatchEnvSampler]: + if env_attr.parallel_rollout: + assert env_attr.env_sampling_parallelism is not None + return BatchEnvSampler( + sampling_parallelism=env_attr.env_sampling_parallelism, + port=env_attr.port, + min_env_samples=env_attr.min_env_samples, + grace_factor=env_attr.grace_factor, + eval_parallelism=env_attr.env_eval_parallelism, + logger=env_attr.logger, ) else: env_sampler = rl_component_bundle.env_sampler - if train_mode != "simple": + if rl_component_bundle.device_mapping is not None: for policy_name, device_name in rl_component_bundle.device_mapping.items(): env_sampler.assign_policy_to_device(policy_name, get_torch_device(device_name)) + return env_sampler + + +def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes, args: argparse.Namespace) -> None: + if args.evaluate_only: + evaluate_only_workflow(rl_component_bundle, env_attr) + else: + training_workflow(rl_component_bundle, env_attr) + + +def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None: + env_attr.logger.info("Start training workflow.") + + env_sampler = _get_env_sampler(rl_component_bundle, env_attr) # evaluation schedule - eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False)) - logger.info(f"Policy will be evaluated at the end of episodes {eval_schedule}") + env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}") eval_point_index = 0 training_manager = TrainingManager( rl_component_bundle=rl_component_bundle, - explicit_assign_device=(train_mode == "simple"), - proxy_address=None - if train_mode == "simple" - else ( - get_env("TRAIN_PROXY_HOST"), - int(get_env("TRAIN_PROXY_FRONTEND_PORT")), - ), - logger=logger, + explicit_assign_device=(env_attr.train_mode == "simple"), + proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address, + logger=env_attr.logger, ) - load_path = get_env("LOAD_PATH", required=False) - load_episode = int_or_none(get_env("LOAD_EPISODE", required=False)) - if load_path: - assert isinstance(load_path, str) + if env_attr.load_path: + assert isinstance(env_attr.load_path, str) - ep = load_episode if load_episode is not None else get_latest_ep(load_path) - path = os.path.join(load_path, str(ep)) + ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path) + path = os.path.join(env_attr.load_path, str(ep)) loaded = env_sampler.load_policy_state(path) - logger.info(f"Loaded policies {loaded} into env sampler from {path}") + env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}") loaded = training_manager.load(path) - logger.info(f"Loaded trainers {loaded} from {path}") + env_attr.logger.info(f"Loaded trainers {loaded} from {path}") start_ep = ep + 1 else: start_ep = 1 - 
checkpoint_path = get_env("CHECKPOINT_PATH", required=False) - checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False)) - # main loop - for ep in range(start_ep, num_episodes + 1): - collect_time = training_time = 0 + for ep in range(start_ep, env_attr.num_episodes + 1): + collect_time = training_time = 0.0 total_experiences: List[List[ExpElement]] = [] total_info_list: List[dict] = [] n_sample = 0 - while n_sample < min_n_sample: + while n_sample < env_attr.min_n_sample: tc0 = time.time() result = env_sampler.sample( - policy_state=training_manager.get_policy_state() if not is_single_thread else None, - num_steps=num_steps, + policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None, + num_steps=env_attr.num_steps, ) experiences: List[List[ExpElement]] = result["experiences"] info_list: List[dict] = result["info"] @@ -128,23 +169,25 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None: env_sampler.post_collect(total_info_list, ep) - logger.info(f"Roll-out completed for episode {ep}. Training started...") + env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...") tu0 = time.time() training_manager.record_experiences(total_experiences) training_manager.train_step() - if checkpoint_path and (checkpoint_interval is None or ep % checkpoint_interval == 0): - assert isinstance(checkpoint_path, str) - pth = os.path.join(checkpoint_path, str(ep)) + if env_attr.checkpoint_path and (not env_attr.checkpoint_interval or ep % env_attr.checkpoint_interval == 0): + assert isinstance(env_attr.checkpoint_path, str) + pth = os.path.join(env_attr.checkpoint_path, str(ep)) training_manager.save(pth) - logger.info(f"All trainer states saved under {pth}") + env_attr.logger.info(f"All trainer states saved under {pth}") training_time += time.time() - tu0 # performance details - logger.info(f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds") - if eval_schedule and ep == eval_schedule[eval_point_index]: + env_attr.logger.info( + f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds", + ) + if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]: eval_point_index += 1 result = env_sampler.eval( - policy_state=training_manager.get_policy_state() if not is_single_thread else None, + policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None, ) env_sampler.post_evaluate(result["info"], ep) @@ -153,42 +196,19 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None: training_manager.exit() -def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None: - logger = LoggerV2( - "MAIN", - dump_path=get_env("LOG_PATH"), - dump_mode="a", - stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), - file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), - ) - logger.info("Start evaluate only workflow.") - - env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False)) - env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False)) - parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None - - if parallel_rollout: - env_sampler = BatchEnvSampler( - sampling_parallelism=env_sampling_parallelism, - port=int(get_env("ROLLOUT_CONTROLLER_PORT")), - min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)), - 
grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)), - eval_parallelism=env_eval_parallelism, - logger=logger, - ) - else: - env_sampler = rl_component_bundle.env_sampler +def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None: + env_attr.logger.info("Start evaluate only workflow.") + + env_sampler = _get_env_sampler(rl_component_bundle, env_attr) - load_path = get_env("LOAD_PATH", required=False) - load_episode = int_or_none(get_env("LOAD_EPISODE", required=False)) - if load_path: - assert isinstance(load_path, str) + if env_attr.load_path: + assert isinstance(env_attr.load_path, str) - ep = load_episode if load_episode is not None else get_latest_ep(load_path) - path = os.path.join(load_path, str(ep)) + ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path) + path = os.path.join(env_attr.load_path, str(ep)) loaded = env_sampler.load_policy_state(path) - logger.info(f"Loaded policies {loaded} into env sampler from {path}") + env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}") result = env_sampler.eval() env_sampler.post_evaluate(result["info"], -1) @@ -198,11 +218,9 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None: if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) - rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") - rl_component_bundle = rl_component_bundle_cls() - main(rl_component_bundle, args=get_args()) + main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=_get_args()) diff --git a/maro/rl/workflows/rollout_worker.py b/maro/rl/workflows/rollout_worker.py index 94da47bde..8343873b3 100644 --- a/maro/rl/workflows/rollout_worker.py +++ b/maro/rl/workflows/rollout_worker.py @@ -4,23 +4,22 @@ import importlib import os import sys -from typing import Type from maro.rl.rl_component.rl_component_bundle import RLComponentBundle from maro.rl.rollout import RolloutWorker from maro.rl.utils.common import get_env, int_or_none +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) - rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") - rl_component_bundle = rl_component_bundle_cls() + rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle") - worker_idx = int_or_none(get_env("ID")) + worker_idx = int(env_str_helper(get_env("ID"))) logger = LoggerV2( f"ROLLOUT-WORKER.{worker_idx}", dump_path=get_env("LOG_PATH"), @@ -31,7 +30,7 @@ worker = RolloutWorker( idx=worker_idx, rl_component_bundle=rl_component_bundle, - producer_host=get_env("ROLLOUT_CONTROLLER_HOST"), + producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")), producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")), logger=logger, ) diff --git a/maro/rl/workflows/train_worker.py b/maro/rl/workflows/train_worker.py index f7aa3753c..4565c5b72 100644 --- 
a/maro/rl/workflows/train_worker.py +++ b/maro/rl/workflows/train_worker.py @@ -4,21 +4,20 @@ import importlib import os import sys -from typing import Type from maro.rl.rl_component.rl_component_bundle import RLComponentBundle from maro.rl.training import TrainOpsWorker from maro.rl.utils.common import get_env, int_or_none +from maro.rl.workflows.utils import env_str_helper from maro.utils import LoggerV2 if __name__ == "__main__": - scenario_path = get_env("SCENARIO_PATH") + scenario_path = env_str_helper(get_env("SCENARIO_PATH")) scenario_path = os.path.normpath(scenario_path) sys.path.insert(0, os.path.dirname(scenario_path)) module = importlib.import_module(os.path.basename(scenario_path)) - rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") - rl_component_bundle = rl_component_bundle_cls() + rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle") worker_idx = int_or_none(get_env("ID")) logger = LoggerV2( @@ -29,9 +28,9 @@ file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), ) worker = TrainOpsWorker( - idx=int_or_none(get_env("ID")), + idx=int(env_str_helper(get_env("ID"))), rl_component_bundle=rl_component_bundle, - producer_host=get_env("TRAIN_PROXY_HOST"), + producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")), producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")), logger=logger, ) diff --git a/maro/rl/workflows/utils.py b/maro/rl/workflows/utils.py new file mode 100644 index 000000000..accfbe86f --- /dev/null +++ b/maro/rl/workflows/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Optional + + +def env_str_helper(string: Optional[str]) -> str: + assert string is not None + return string diff --git a/maro/simulator/abs_core.py b/maro/simulator/abs_core.py index b47d94baa..f50ca53cb 100644 --- a/maro/simulator/abs_core.py +++ b/maro/simulator/abs_core.py @@ -72,7 +72,7 @@ def business_engine(self) -> AbsBusinessEngine: return self._business_engine @abstractmethod - def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]: + def step(self, action) -> Tuple[Optional[dict], Optional[list], bool]: """Push the environment to next step with action. Args: diff --git a/notebooks/container_inventory_management/rl_formulation.ipynb b/notebooks/container_inventory_management/rl_formulation.ipynb new file mode 100644 index 000000000..c49c51e58 --- /dev/null +++ b/notebooks/container_inventory_management/rl_formulation.ipynb @@ -0,0 +1,452 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quick Start\n", + "\n", + "This notebook demonstrates the use of MARO's RL toolkit to optimize container inventory management (CIM). The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time."
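# Hedged usage sketch for the env_str_helper added above: it only narrows the Optional[str]
# returned by get_env to str so that conversions such as int() type-check under mypy
# (get_env itself raises KeyError for a missing required variable). Variable names are illustrative.
from maro.rl.utils.common import get_env, int_or_none
from maro.rl.workflows.utils import env_str_helper

worker_idx = int(env_str_helper(get_env("ID")))  # required variable, narrowed then converted
backend_port = int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT", required=False))  # Optional[int]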
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Import necessary packages\n", + "from typing import Any, Dict, List, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from torch.optim import Adam, RMSprop\n", + "\n", + "from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet\n", + "from maro.rl.policy import DiscretePolicyGradient\n", + "from maro.rl.rl_component.rl_component_bundle import RLComponentBundle\n", + "from maro.rl.rollout import AbsEnvSampler, CacheElement, ExpElement\n", + "from maro.rl.training import TrainingManager\n", + "from maro.rl.training.algorithms import PPOParams, PPOTrainer\n", + "from maro.simulator import Env\n", + "from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# env and shaping config\n", + "reward_shaping_conf = {\n", + " \"time_window\": 99,\n", + " \"fulfillment_factor\": 1.0,\n", + " \"shortage_factor\": 1.0,\n", + " \"time_decay\": 0.97,\n", + "}\n", + "state_shaping_conf = {\n", + " \"look_back\": 7,\n", + " \"max_ports_downstream\": 2,\n", + "}\n", + "port_attributes = [\"empty\", \"full\", \"on_shipper\", \"on_consignee\", \"booking\", \"shortage\", \"fulfillment\"]\n", + "vessel_attributes = [\"empty\", \"full\", \"remaining_space\"]\n", + "action_shaping_conf = {\n", + " \"action_space\": [(i - 10) / 10 for i in range(21)],\n", + " \"finite_vessel_space\": True,\n", + " \"has_early_discharge\": True,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment Sampler\n", + "\n", + "An environment sampler defines state, action and reward shaping logic so that policies can interact with the environment." 
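# Hedged outline (not from the notebook) of the shaping hooks an AbsEnvSampler subclass
# overrides; the complete CIM implementation follows in the next cell, so the bodies here
# are placeholders only.
from maro.rl.rollout import AbsEnvSampler


class EnvSamplerOutline(AbsEnvSampler):  # illustrative name
    def _get_global_and_agent_state_impl(self, event, tick=None):
        ...  # state shaping: build the global state and per-agent states from snapshots

    def _translate_to_env_action(self, action_dict, event):
        ...  # action shaping: convert model outputs into simulator Action objects

    def _get_reward(self, env_action_dict, event, tick):
        ...  # reward shaping: compute per-agent rewards for the actions taken at `tick`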
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class CIMEnvSampler(AbsEnvSampler):\n", + " def _get_global_and_agent_state_impl(\n", + " self, event: DecisionEvent, tick: int = None,\n", + " ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:\n", + " tick = self._env.tick\n", + " vessel_snapshots, port_snapshots = self._env.snapshot_list[\"vessels\"], self._env.snapshot_list[\"ports\"]\n", + " port_idx, vessel_idx = event.port_idx, event.vessel_idx\n", + " ticks = [max(0, tick - rt) for rt in range(state_shaping_conf[\"look_back\"] - 1)]\n", + " future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')\n", + " state = np.concatenate([\n", + " port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],\n", + " vessel_snapshots[tick: vessel_idx: vessel_attributes]\n", + " ])\n", + " return state, {port_idx: state}\n", + "\n", + " def _translate_to_env_action(\n", + " self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,\n", + " ) -> Dict[Any, object]:\n", + " action_space = action_shaping_conf[\"action_space\"]\n", + " finite_vsl_space = action_shaping_conf[\"finite_vessel_space\"]\n", + " has_early_discharge = action_shaping_conf[\"has_early_discharge\"]\n", + "\n", + " port_idx, model_action = list(action_dict.items()).pop()\n", + "\n", + " vsl_idx, action_scope = event.vessel_idx, event.action_scope\n", + " vsl_snapshots = self._env.snapshot_list[\"vessels\"]\n", + " vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float(\"inf\")\n", + "\n", + " percent = abs(action_space[model_action[0]])\n", + " zero_action_idx = len(action_space) / 2 # index corresponding to value zero.\n", + " if model_action < zero_action_idx:\n", + " action_type = ActionType.LOAD\n", + " actual_action = min(round(percent * action_scope.load), vsl_space)\n", + " elif model_action > zero_action_idx:\n", + " action_type = ActionType.DISCHARGE\n", + " early_discharge = vsl_snapshots[self._env.tick:vsl_idx:\"early_discharge\"][0] if has_early_discharge else 0\n", + " plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge\n", + " actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)\n", + " else:\n", + " actual_action, action_type = 0, None\n", + "\n", + " return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}\n", + "\n", + " def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:\n", + " start_tick = tick + 1\n", + " ticks = list(range(start_tick, start_tick + reward_shaping_conf[\"time_window\"]))\n", + "\n", + " # Get the ports that took actions at the given tick\n", + " ports = [int(port) for port in list(env_action_dict.keys())]\n", + " port_snapshots = self._env.snapshot_list[\"ports\"]\n", + " future_fulfillment = port_snapshots[ticks:ports:\"fulfillment\"].reshape(len(ticks), -1)\n", + " future_shortage = port_snapshots[ticks:ports:\"shortage\"].reshape(len(ticks), -1)\n", + "\n", + " decay_list = [reward_shaping_conf[\"time_decay\"] ** i for i in range(reward_shaping_conf[\"time_window\"])]\n", + " rewards = np.float32(\n", + " reward_shaping_conf[\"fulfillment_factor\"] * np.dot(future_fulfillment.T, decay_list)\n", + " - reward_shaping_conf[\"shortage_factor\"] * np.dot(future_shortage.T, decay_list)\n", + " )\n", + " return 
{agent_id: reward for agent_id, reward in zip(ports, rewards)}\n", + "\n", + " def _post_step(self, cache_element: CacheElement) -> None:\n", + " self._info[\"env_metric\"] = self._env.metrics\n", + "\n", + " def _post_eval_step(self, cache_element: CacheElement) -> None:\n", + " self._post_step(cache_element)\n", + "\n", + " def post_collect(self, info_list: list, ep: int) -> None:\n", + " # print the env metric from each rollout worker\n", + " for info in info_list:\n", + " print(f\"env summary (episode {ep}): {info['env_metric']}\")\n", + "\n", + " # print the average env metric\n", + " if len(info_list) > 1:\n", + " metric_keys, num_envs = info_list[0][\"env_metric\"].keys(), len(info_list)\n", + " avg_metric = {key: sum(info[\"env_metric\"][key] for info in info_list) / num_envs for key in metric_keys}\n", + " print(f\"average env summary (episode {ep}): {avg_metric}\")\n", + "\n", + " def post_evaluate(self, info_list: list, ep: int) -> None:\n", + " self.post_collect(info_list, ep)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Policies & Trainers" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "state_dim = (\n", + " (state_shaping_conf[\"look_back\"] + 1) * (state_shaping_conf[\"max_ports_downstream\"] + 1) * len(port_attributes)\n", + " + len(vessel_attributes)\n", + ")\n", + "action_num = len(action_shaping_conf[\"action_space\"])\n", + "\n", + "actor_net_conf = {\n", + " \"hidden_dims\": [256, 128, 64],\n", + " \"activation\": torch.nn.Tanh,\n", + " \"softmax\": True,\n", + " \"batch_norm\": False,\n", + " \"head\": True,\n", + "}\n", + "critic_net_conf = {\n", + " \"hidden_dims\": [256, 128, 64],\n", + " \"output_dim\": 1,\n", + " \"activation\": torch.nn.LeakyReLU,\n", + " \"softmax\": False,\n", + " \"batch_norm\": True,\n", + " \"head\": True,\n", + "}\n", + "\n", + "actor_learning_rate = 0.001\n", + "critic_learning_rate = 0.001\n", + "\n", + "class MyActorNet(DiscreteACBasedNet):\n", + " def __init__(self, state_dim: int, action_num: int) -> None:\n", + " super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)\n", + " self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)\n", + " self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)\n", + "\n", + " def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:\n", + " return self._actor(states)\n", + "\n", + "\n", + "class MyCriticNet(VNet):\n", + " def __init__(self, state_dim: int) -> None:\n", + " super(MyCriticNet, self).__init__(state_dim=state_dim)\n", + " self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)\n", + " self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)\n", + "\n", + " def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:\n", + " return self._critic(states).squeeze(-1)\n", + "\n", + "def get_ppo_trainer(state_dim: int, name: str) -> PPOTrainer:\n", + " return PPOTrainer(\n", + " name=name,\n", + " reward_discount=.0,\n", + " params=PPOParams(\n", + " get_v_critic_net_func=lambda: MyCriticNet(state_dim),\n", + " grad_iters=10,\n", + " critic_loss_cls=torch.nn.SmoothL1Loss,\n", + " min_logp=None,\n", + " lam=.0,\n", + " clip_ratio=0.1,\n", + " ),\n", + " )\n", + "\n", + "learn_env = Env(scenario=\"cim\", topology=\"toy.4p_ssdd_l0.0\", durations=500)\n", + "test_env = learn_env\n", + "num_agents = len(learn_env.agent_idx_list)\n", + "agent2policy = {agent: f\"ppo_{agent}.policy\"for 
agent in learn_env.agent_idx_list}\n", + "policies = [DiscretePolicyGradient(name=f\"ppo_{i}.policy\", policy_net=MyActorNet(state_dim, action_num)) for i in range(num_agents)]\n", + "trainers = [get_ppo_trainer(state_dim, f\"ppo_{i}\") for i in range(num_agents)]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RL component bundle\n", + "\n", + "An RL component bundle integrate all necessary resources to launch a learning loop." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "rl_component_bundle = RLComponentBundle(\n", + " env_sampler=CIMEnvSampler(\n", + " learn_env=learn_env,\n", + " test_env=test_env,\n", + " policies=policies,\n", + " agent2policy=agent2policy,\n", + " reward_eval_delay=reward_shaping_conf[\"time_window\"],\n", + " ),\n", + " agent2policy=agent2policy,\n", + " policies=policies,\n", + " trainers=trainers,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Learning Loop\n", + "\n", + "This code cell demonstrates a typical single-threaded training workflow." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting result:\n", + "env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 688632, 'operation_number': 1940226}\n", + "\n", + "Collecting result:\n", + "env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 601337, 'operation_number': 2030600}\n", + "\n", + "Collecting result:\n", + "env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 544572, 'operation_number': 1737291}\n", + "\n", + "Collecting result:\n", + "env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 545506, 'operation_number': 2008160}\n", + "\n", + "Collecting result:\n", + "env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 442000, 'operation_number': 1935439}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 533699, 'operation_number': 891248}\n", + "\n", + "Collecting result:\n", + "env summary (episode 6): {'order_requirements': 1000000, 'container_shortage': 448461, 'operation_number': 1918664}\n", + "\n", + "Collecting result:\n", + "env summary (episode 7): {'order_requirements': 1000000, 'container_shortage': 469874, 'operation_number': 1745973}\n", + "\n", + "Collecting result:\n", + "env summary (episode 8): {'order_requirements': 1000000, 'container_shortage': 364469, 'operation_number': 1974592}\n", + "\n", + "Collecting result:\n", + "env summary (episode 9): {'order_requirements': 1000000, 'container_shortage': 425449, 'operation_number': 1821885}\n", + "\n", + "Collecting result:\n", + "env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 386687, 'operation_number': 1798356}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 950000, 'operation_number': 0}\n", + "\n", + "Collecting result:\n", + "env summary (episode 11): {'order_requirements': 1000000, 'container_shortage': 403236, 'operation_number': 1742253}\n", + "\n", + "Collecting result:\n", + "env summary (episode 12): {'order_requirements': 1000000, 'container_shortage': 373426, 'operation_number': 1682848}\n", + "\n", + "Collecting result:\n", + "env summary (episode 13): {'order_requirements': 
1000000, 'container_shortage': 357254, 'operation_number': 1845318}\n", + "\n", + "Collecting result:\n", + "env summary (episode 14): {'order_requirements': 1000000, 'container_shortage': 215681, 'operation_number': 1969606}\n", + "\n", + "Collecting result:\n", + "env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 288347, 'operation_number': 1739670}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 639517, 'operation_number': 680980}\n", + "\n", + "Collecting result:\n", + "env summary (episode 16): {'order_requirements': 1000000, 'container_shortage': 258659, 'operation_number': 1747509}\n", + "\n", + "Collecting result:\n", + "env summary (episode 17): {'order_requirements': 1000000, 'container_shortage': 202262, 'operation_number': 1982958}\n", + "\n", + "Collecting result:\n", + "env summary (episode 18): {'order_requirements': 1000000, 'container_shortage': 209018, 'operation_number': 1765574}\n", + "\n", + "Collecting result:\n", + "env summary (episode 19): {'order_requirements': 1000000, 'container_shortage': 256471, 'operation_number': 1764379}\n", + "\n", + "Collecting result:\n", + "env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 259231, 'operation_number': 1737222}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 9000, 'operation_number': 1974766}\n", + "\n", + "Collecting result:\n", + "env summary (episode 21): {'order_requirements': 1000000, 'container_shortage': 268553, 'operation_number': 1697234}\n", + "\n", + "Collecting result:\n", + "env summary (episode 22): {'order_requirements': 1000000, 'container_shortage': 212987, 'operation_number': 1788601}\n", + "\n", + "Collecting result:\n", + "env summary (episode 23): {'order_requirements': 1000000, 'container_shortage': 234729, 'operation_number': 1803468}\n", + "\n", + "Collecting result:\n", + "env summary (episode 24): {'order_requirements': 1000000, 'container_shortage': 224261, 'operation_number': 1736261}\n", + "\n", + "Collecting result:\n", + "env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 191424, 'operation_number': 1952505}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 606940, 'operation_number': 710472}\n", + "\n", + "Collecting result:\n", + "env summary (episode 26): {'order_requirements': 1000000, 'container_shortage': 223272, 'operation_number': 1895614}\n", + "\n", + "Collecting result:\n", + "env summary (episode 27): {'order_requirements': 1000000, 'container_shortage': 427395, 'operation_number': 1351830}\n", + "\n", + "Collecting result:\n", + "env summary (episode 28): {'order_requirements': 1000000, 'container_shortage': 266455, 'operation_number': 1924877}\n", + "\n", + "Collecting result:\n", + "env summary (episode 29): {'order_requirements': 1000000, 'container_shortage': 362452, 'operation_number': 1747022}\n", + "\n", + "Collecting result:\n", + "env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 320532, 'operation_number': 1506639}\n", + "\n", + "Evaluation result:\n", + "env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 639581, 'operation_number': 681708}\n", + "\n" + ] + } + ], + "source": [ + "env_sampler = rl_component_bundle.env_sampler\n", + "\n", + "num_episodes = 30\n", + "eval_schedule = [5, 10, 15, 20, 25, 30]\n", + 
"eval_point_index = 0\n", + "\n", + "training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)\n", + "\n", + "# main loop\n", + "for ep in range(1, num_episodes + 1):\n", + " result = env_sampler.sample()\n", + " experiences: List[List[ExpElement]] = result[\"experiences\"]\n", + " info_list: List[dict] = result[\"info\"]\n", + " \n", + " print(\"Collecting result:\")\n", + " env_sampler.post_collect(info_list, ep)\n", + " print()\n", + "\n", + " training_manager.record_experiences(experiences)\n", + " training_manager.train_step()\n", + "\n", + " if ep == eval_schedule[eval_point_index]:\n", + " eval_point_index += 1\n", + " result = env_sampler.eval()\n", + " \n", + " print(\"Evaluation result:\")\n", + " env_sampler.post_evaluate(result[\"info\"], ep)\n", + " print()\n", + "\n", + "training_manager.exit()" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "8f57a09d39b50edfb56e79199ef40583334d721b06ead0e38a39e7e79092073c" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/requirements.test.txt b/tests/requirements.test.txt index c7ac76f45..e44b03370 100644 --- a/tests/requirements.test.txt +++ b/tests/requirements.test.txt @@ -6,7 +6,7 @@ deepdiff>=5.7.0 geopy>=2.0.0 holidays>=0.10.3 kubernetes>=21.7.0 -numpy>=1.19.5 +numpy>=1.19.5,<1.24.0 pandas>=0.25.3 paramiko>=2.9.2 pytest>=7.1.2 diff --git a/tests/test_env.py b/tests/test_env.py index d96950d3c..ce257b831 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -326,7 +326,7 @@ def test_early_stop(self): msg=f"env should stop at tick 6, but {env.tick}", ) - # avaiable snapshot should be 7 (0-6) + # available snapshot should be 7 (0-6) states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10) self.assertEqual(