diff --git a/examples/rl/cim/env_sampler.py b/examples/rl/cim/env_sampler.py index 82a7cc7cd..d5ab368eb 100644 --- a/examples/rl/cim/env_sampler.py +++ b/examples/rl/cim/env_sampler.py @@ -32,7 +32,7 @@ def get_state(self, tick=None): vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"] port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)] - future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') + future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') state = np.concatenate([ port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes], vessel_snapshots[tick : vessel_idx : vessel_attributes] @@ -55,7 +55,7 @@ def get_env_actions(self, action_by_agent): vsl_snapshots = self.env.snapshot_list["vessels"] vsl_space = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf") - model_action = action["action"] if isinstance(action, dict) else action + model_action = action["action"] if isinstance(action, dict) else action percent = abs(action_space[model_action]) zero_action_idx = len(action_space) / 2 # index corresponding to value zero. if model_action < zero_action_idx: @@ -112,5 +112,5 @@ def get_env_sampler(): get_policy_func_dict=policy_func_dict, agent2policy=agent2policy, reward_eval_delay=reward_shaping_conf["time_window"], - parallel_inference=True + parallel_inference=False ) diff --git a/examples/rl/cim_v2/README.md b/examples/rl/cim_v2/README.md new file mode 100644 index 000000000..5113acd18 --- /dev/null +++ b/examples/rl/cim_v2/README.md @@ -0,0 +1,9 @@ +# Container Inventory Management + +This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to be loaded or discharged. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find: +* ``config.py``, which contains environment and policy configurations for the scenario; +* ``env_sampler.py``, which defines state, action and reward shaping in the ``CIMEnvSampler`` class; +* ``policies.py``, which defines the Q-net for DQN and the network components for Actor-Critic; +* ``callbacks.py``, which defines routines to be invoked at the end of training or evaluation episodes. + +The scripts for running the learning workflows can be found under ``examples/rl/workflows``. See ``README`` under ``examples/rl`` for details about the general applicability of these scripts. We recommend that you follow this example to write your own scenarios. \ No newline at end of file diff --git a/examples/rl/cim_v2/__init__.py b/examples/rl/cim_v2/__init__.py new file mode 100644 index 000000000..c3e93fabd --- /dev/null +++ b/examples/rl/cim_v2/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from .callbacks import post_collect, post_evaluate +from .env_sampler import agent2policy, get_env_sampler +from .policies import policy_func_dict + +__all__ = ["agent2policy", "post_collect", "post_evaluate", "get_env_sampler", "policy_func_dict"] diff --git a/examples/rl/cim_v2/callbacks.py b/examples/rl/cim_v2/callbacks.py new file mode 100644 index 000000000..a5d6d1edb --- /dev/null +++ b/examples/rl/cim_v2/callbacks.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import time +from os import makedirs +from os.path import dirname, join, realpath + +log_dir = join(dirname(realpath(__file__)), "log", str(time.time())) +makedirs(log_dir, exist_ok=True) + + +def post_collect(trackers, ep, segment): + # print the env metric from each rollout worker + for tracker in trackers: + print(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}") + + # print the average env metric + if len(trackers) > 1: + metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers) + avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys} + print(f"average env summary (episode {ep}, segment {segment}): {avg_metric}") + + +def post_evaluate(trackers, ep): + # print the env metric from each rollout worker + for tracker in trackers: + print(f"env summary (episode {ep}): {tracker['env_metric']}") + + # print the average env metric + if len(trackers) > 1: + metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers) + avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys} + print(f"average env summary (episode {ep}): {avg_metric}") diff --git a/examples/rl/cim_v2/config.py b/examples/rl/cim_v2/config.py new file mode 100644 index 000000000..a43c61c99 --- /dev/null +++ b/examples/rl/cim_v2/config.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
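Note: as a quick standalone illustration of the metric averaging done in ``post_collect``/``post_evaluate`` in callbacks.py above, a minimal sketch (the tracker contents and metric keys here are made up for the example):

# Hypothetical trackers as produced by two roll-out workers.
trackers = [
    {"env_metric": {"shortage": 120.0, "fulfillment": 880.0}},
    {"env_metric": {"shortage": 100.0, "fulfillment": 900.0}},
]
metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
print(avg_metric)  # {'shortage': 110.0, 'fulfillment': 890.0}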
+ +import torch +from torch.optim import Adam, RMSprop + +from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy + + +env_conf = { + "scenario": "cim", + "topology": "toy.4p_ssdd_l0.0", + "durations": 560 +} + +port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"] +vessel_attributes = ["empty", "full", "remaining_space"] + +state_shaping_conf = { + "look_back": 7, + "max_ports_downstream": 2 +} + +action_shaping_conf = { + "action_space": [(i - 10) / 10 for i in range(21)], + "finite_vessel_space": True, + "has_early_discharge": True +} + +reward_shaping_conf = { + "time_window": 99, + "fulfillment_factor": 1.0, + "shortage_factor": 1.0, + "time_decay": 0.97 +} + +# obtain state dimension from a temporary env_wrapper instance +state_dim = ( + (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes) + + len(vessel_attributes) +) + +############################################## POLICIES ############################################### + +algorithm = "ac" + +# DQN settings +q_net_conf = { + "input_dim": state_dim, + "hidden_dims": [256, 128, 64, 32], + "output_dim": len(action_shaping_conf["action_space"]), + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "skip_connection": False, + "head": True, + "dropout_p": 0.0 +} + +q_net_optim_conf = (RMSprop, {"lr": 0.05}) + +dqn_conf = { + "reward_discount": .0, + "update_target_every": 5, + "num_epochs": 10, + "soft_update_coef": 0.1, + "double": False, + "exploration_strategy": (epsilon_greedy, {"epsilon": 0.4}), + "exploration_scheduling_options": [( + "epsilon", MultiLinearExplorationScheduler, { + "splits": [(2, 0.32)], + "initial_value": 0.4, + "last_ep": 5, + "final_value": 0.0, + } + )], + "replay_memory_capacity": 10000, + "random_overwrite": False, + "warmup": 100, + "rollout_batch_size": 128, + "train_batch_size": 32, + # "prioritized_replay_kwargs": { + # "alpha": 0.6, + # "beta": 0.4, + # "beta_step": 0.001, + # "max_priority": 1e8 + # } +} + + +# AC settings +actor_net_conf = { + "input_dim": state_dim, + "hidden_dims": [256, 128, 64], + "output_dim": len(action_shaping_conf["action_space"]), + "activation": torch.nn.Tanh, + "softmax": True, + "batch_norm": False, + "head": True +} + +critic_net_conf = { + "input_dim": state_dim, + "hidden_dims": [256, 128, 64], + "output_dim": 1, + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "head": True +} + +actor_optim_conf = (Adam, {"lr": 0.001}) +critic_optim_conf = (RMSprop, {"lr": 0.001}) + +ac_conf = { + "reward_discount": .0, + "grad_iters": 10, + "critic_loss_cls": torch.nn.SmoothL1Loss, + "min_logp": None, + "critic_loss_coef": 0.1, + "entropy_coef": 0.01, + # "clip_ratio": 0.8 # for PPO + "lam": .0, + "get_loss_on_rollout": False +} diff --git a/examples/rl/cim_v2/env_sampler.py b/examples/rl/cim_v2/env_sampler.py new file mode 100644 index 000000000..41e68d87d --- /dev/null +++ b/examples/rl/cim_v2/env_sampler.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
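Note: the state dimension and action-space mapping defined in config.py above can be sanity-checked in isolation; a small sketch using the same constants (not part of the example code itself):

port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]
look_back, max_ports_downstream = 7, 2

# Same formula as config.py: one slot per (tick, port) pair plus the vessel features.
state_dim = (look_back + 1) * (max_ports_downstream + 1) * len(port_attributes) + len(vessel_attributes)
print(state_dim)  # 8 * 3 * 7 + 3 = 171

# The 21 discrete actions map to load/discharge percentages in [-1.0, 1.0]; index 10 is "do nothing".
action_space = [(i - 10) / 10 for i in range(21)]
print(action_space[0], action_space[10], action_space[20])  # -1.0 0.0 1.0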
+ +import os +import sys + +import numpy as np + +from maro.rl.learning.env_sampler_v2 import AbsEnvSampler +from maro.simulator import Env +from maro.simulator.scenarios.cim.common import Action, ActionType + +cim_path = os.path.dirname(os.path.realpath(__file__)) +if cim_path not in sys.path: + sys.path.insert(0, cim_path) + +from config import ( + action_shaping_conf, algorithm, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf, + vessel_attributes +) +from policies import policy_func_dict + + +class CIMEnvSampler(AbsEnvSampler): + def get_state(self, tick=None): + """ + The state vector includes shortage and remaining vessel space over the past k days (where k is the "look_back" + value in ``state_shaping_conf``), as well as all downstream port features. + """ + if tick is None: + tick = self._env.tick + vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"] + port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx + ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)] + future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') + state = np.concatenate([ + port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes], + vessel_snapshots[tick : vessel_idx : vessel_attributes] + ]) + return {port_idx: state} + + def get_env_actions(self, action_by_agent): + """ + The policy output is an integer from [0, 20] which is to be interpreted as the index of ``action_space`` in + ``action_shaping_conf``. For example, action 5 corresponds to -0.5, which means loading 50% of the containers + available at the current port to the vessel, while action 18 corresponds to 0.8, which means loading 80% of the + containers on the vessel to the port. Note that action 10 corresponds 0.0, which means doing nothing. + """ + action_space = action_shaping_conf["action_space"] + finite_vsl_space = action_shaping_conf["finite_vessel_space"] + has_early_discharge = action_shaping_conf["has_early_discharge"] + + port_idx, action = list(action_by_agent.items()).pop() + vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope + vsl_snapshots = self._env.snapshot_list["vessels"] + vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf") + + model_action = action["action"] if isinstance(action, dict) else action + percent = abs(action_space[model_action]) + zero_action_idx = len(action_space) / 2 # index corresponding to value zero. + if model_action < zero_action_idx: + action_type = ActionType.LOAD + actual_action = min(round(percent * action_scope.load), vsl_space) + elif model_action > zero_action_idx: + action_type = ActionType.DISCHARGE + early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0 + plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge + actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge) + else: + actual_action, action_type = 0, None + + return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)] + + def get_reward(self, actions, tick): + """ + The reward is defined as a linear combination of fulfillment and shortage measures. 
The fulfillment and + shortage measures are the sums of fulfillment and shortage values over the next k days, respectively, each + adjusted with exponential decay factors (using the "time_decay" value in ``reward_shaping_conf``) to put more + emphasis on the near future. Here k is the "time_window" value in ``reward_shaping_conf``. The linear + combination coefficients are given by "fulfillment_factor" and "shortage_factor" in ``reward_shaping_conf``. + """ + start_tick = tick + 1 + ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"])) + + # Get the ports that took actions at the given tick + ports = [action.port_idx for action in actions] + port_snapshots = self._env.snapshot_list["ports"] + future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1) + future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1) + + decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])] + rewards = np.float32( + reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list) + - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list) + ) + return {agent_id: reward for agent_id, reward in zip(ports, rewards)} + + def post_step(self, state, action, env_action, reward, tick): + """ + The environment sampler contains a "tracker" dict inherited from the "AbsEnvSampler" base class, which can + be used to record any information one wishes to keep track of during a roll-out episode. Here we simply record + the latest env metric without keeping the history for logging purposes. + """ + self._tracker["env_metric"] = self._env.metrics + + +agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list} + +def get_env_sampler(): + return CIMEnvSampler( + get_env=lambda: Env(**env_conf), + get_policy_func_dict=policy_func_dict, + agent2policy=agent2policy, + reward_eval_delay=reward_shaping_conf["time_window"], + parallel_inference=False + ) diff --git a/examples/rl/cim_v2/policies.py b/examples/rl/cim_v2/policies.py new file mode 100644 index 000000000..712ad22dd --- /dev/null +++ b/examples/rl/cim_v2/policies.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
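Note: the time-decayed reward computed in ``get_reward`` above reduces to a pair of dot products; a minimal sketch with a shortened window and made-up per-tick series for a single port (config.py uses time_window=99, time_decay=0.97 and unit factors):

import numpy as np

time_window, time_decay = 5, 0.97
fulfillment = np.array([10., 8., 6., 4., 2.])   # hypothetical future fulfillment per tick
shortage = np.array([0., 1., 2., 3., 4.])       # hypothetical future shortage per tick
decay = np.array([time_decay ** i for i in range(time_window)])

# fulfillment_factor * <fulfillment, decay> - shortage_factor * <shortage, decay>
reward = 1.0 * fulfillment.dot(decay) - 1.0 * shortage.dot(decay)
print(float(reward))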
+ +import os +import sys +from typing import Any, Tuple + +import torch + +from maro.rl.modeling import FullyConnected +from maro.rl.modeling_v2 import DiscreteQNetwork, DiscreteVActorCriticNet +from maro.rl.policy_v2 import DQN, DiscreteActorCritic + +cim_path = os.path.dirname(os.path.realpath(__file__)) +if cim_path not in sys.path: + sys.path.insert(0, cim_path) +from config import ( + ac_conf, actor_net_conf, actor_optim_conf, algorithm, critic_net_conf, critic_optim_conf, + dqn_conf, q_net_conf, + q_net_optim_conf +) + + +class MyQNet(DiscreteQNetwork): + def _forward_unimplemented(self, *input: Any) -> None: + pass + + def __init__(self) -> None: + super(MyQNet, self).__init__(state_dim=q_net_conf["input_dim"], action_num=q_net_conf["output_dim"]) + self.fc = FullyConnected(**q_net_conf) + self.optim = q_net_optim_conf[0](self.fc.parameters(), **q_net_optim_conf[1]) + + def forward(self, x): + raise NotImplementedError + + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + return self.fc(states) + + def step(self, loss: torch.tensor) -> None: + self.optim.zero_grad() + loss.backward() + self.optim.step() + + def get_gradients(self, loss: torch.tensor) -> torch.tensor: + self.optim.zero_grad() + loss.backward() + return {name: param.grad for name, param in self.named_parameters()} + + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + + self.optim.step() + + def get_state(self) -> object: + return {"network": self.state_dict(), "optim": self.optim.state_dict()} + + def set_state(self, state: dict) -> None: + self.load_state_dict(state["network"]) + self.optim.load_state_dict(state["optim"]) + + +class MyACNet(DiscreteVActorCriticNet): + def __init__(self) -> None: + super(MyACNet, self).__init__(state_dim=actor_net_conf["input_dim"], action_num=actor_net_conf["output_dim"]) + self.actor = FullyConnected(**actor_net_conf) + self.critic = FullyConnected(**critic_net_conf) + self.actor_optim = actor_optim_conf[0](self.actor.parameters(), **actor_optim_conf[1]) + self.critic_optim = critic_optim_conf[0](self.critic.parameters(), **critic_optim_conf[1]) + + def _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + return self.critic(states).squeeze(-1) + + def _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + return self.actor(states) + + def step(self, loss: torch.tensor) -> None: + self.actor_optim.zero_grad() + self.critic_optim.zero_grad() + loss.backward() + self.actor_optim.step() + self.critic_optim.step() + + def get_gradients(self, loss: torch.tensor) -> torch.tensor: + self.actor_optim.zero_grad() + self.critic_optim.zero_grad() + loss.backward() + return {name: param.grad for name, param in self.named_parameters()} + + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + + self.actor_optim.step() + self.critic_optim.step() + + def get_state(self) -> dict: + return { + "network": self.state_dict(), + "actor_optim": self.actor_optim.state_dict(), + "critic_optim": self.critic_optim.state_dict() + } + + def set_state(self, state: dict) -> None: + self.load_state_dict(state["network"]) + self.actor_optim.load_state_dict(state["actor_optim"]) + self.critic_optim.load_state_dict(state["critic_optim"]) + + def _forward_unimplemented(self, *input: Any) -> None: + pass + + +if algorithm == "dqn": + policy_func_dict = { + f"{algorithm}.{i}": lambda name: DQN(name, MyQNet(), **dqn_conf) for i in range(4) 
+ } +elif algorithm == "ac": + policy_func_dict = { + f"{algorithm}.{i}": lambda name: DiscreteActorCritic(name, MyACNet(), **ac_conf) for i in range(4) + } +else: + raise ValueError diff --git a/examples/rl/config.yml b/examples/rl/config.yml index a1a15e640..e6f42a57f 100644 --- a/examples/rl/config.yml +++ b/examples/rl/config.yml @@ -3,12 +3,12 @@ job: cim scenario_dir: "/maro/examples" -scenario: cim +scenario: cim_v2 load_policy_dir: "/maro/examples/checkpoints/cim" checkpoint_dir: "/maro/examples/checkpoints/cim" log_dir: "/maro/examples/logs/cim" -mode: sync # single, sync, async -num_episodes: 10 +mode: single # single, sync, async +num_episodes: 50 eval_schedule: 5 sync: num_rollouts: 3 diff --git a/maro/rl/learning/env_sampler_v2.py b/maro/rl/learning/env_sampler_v2.py new file mode 100644 index 000000000..bd211a05c --- /dev/null +++ b/maro/rl/learning/env_sampler_v2.py @@ -0,0 +1,590 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from multiprocessing import Pipe, Process +from os import getcwd, path +from typing import Callable, Dict + +import numpy as np + +from maro.communication import Proxy, SessionMessage, SessionType +from maro.rl.policy_v2 import RLPolicy +from maro.rl.utils import MsgKey, MsgTag +from maro.simulator import Env +from maro.utils import Logger, clone + +from .helpers import get_rollout_finish_msg + + +class AbsAgentWrapper(object): + def __init__(self, aid2pid: Dict[str, str]) -> None: + self._aid2pid = aid2pid + + @abstractmethod + def load(self, dir: str) -> None: + pass + + @abstractmethod + def choose_action(self, states_by_agent_id: Dict[str, np.ndarray]) -> dict: + pass + + @abstractmethod + def set_policy_states(self, states_by_policy_id: dict) -> None: + pass + + @abstractmethod + def explore(self) -> None: + pass + + @abstractmethod + def exploit(self) -> None: + pass + + @abstractmethod + def exploration_step(self) -> None: + pass + + @abstractmethod + def get_rollout_info(self) -> dict: + pass + + @abstractmethod + def get_exploration_params(self) -> dict: + pass + + @abstractmethod + def record_transition( + self, agent: str, state: np.ndarray, action: dict, reward: float, next_state: np.ndarray, terminal: bool + ) -> None: + pass + + @abstractmethod + def improve(self, checkpoint_dir: str = None) -> None: + pass + + +class SimpleAgentWrapper(AbsAgentWrapper): + """Wrapper for multiple agents using multiple policies to expose simple single-agent interfaces.""" + def __init__(self, get_policy_func_dict: Dict[str, Callable], aid2pid: Dict[str, str]) -> None: + super(SimpleAgentWrapper, self).__init__(aid2pid) + + self._pid2policy: Dict[str, RLPolicy] = { + policy_id: func(policy_id) for policy_id, func in get_policy_func_dict.items() + } + self._aid2policy: Dict[str, RLPolicy] = { + agent_id: self._pid2policy[policy_id] for agent_id, policy_id in self._aid2pid.items() + } + + def load(self, dir: str) -> None: + for policy_id, policy in self._pid2policy.items(): + pth = path.join(dir, policy_id) + if path.exists(pth): + policy.load(pth) + + def choose_action(self, states_by_agent_id: Dict[str, np.ndarray]) -> dict: + pid2states, pid2aids = defaultdict(list), defaultdict(list) + for agent_id, state in states_by_agent_id.items(): + policy_id = self._aid2pid[agent_id] + pid2states[policy_id].append(state) + pid2aids[policy_id].append(agent_id) + + actions_by_agent_id = {} + # compute the actions for local policies first while the 
inferences processes do their work. + for policy_id, policy in self._pid2policy.items(): + if pid2states[policy_id]: + actions_by_agent_id.update( + zip(pid2aids[policy_id], policy(np.vstack(pid2states[policy_id]))) + ) + + return actions_by_agent_id + + def set_policy_states(self, states_by_policy_id: dict) -> None: + for policy_id, state in states_by_policy_id.items(): + self._pid2policy[policy_id].set_state(state) + + def explore(self) -> None: + for policy in self._pid2policy.values(): + policy.explore() + + def exploit(self) -> None: + for policy in self._pid2policy.values(): + policy.exploit() + + def exploration_step(self) -> None: + for policy in self._pid2policy.values(): + if hasattr(policy, "exploration_step"): + policy.exploration_step() + + def get_rollout_info(self) -> dict: + return { + policy_id: policy.get_rollout_info() for policy_id, policy in self._pid2policy.items() + if isinstance(policy, RLPolicy) + } + + def get_exploration_params(self) -> dict: + return { + policy_id: clone(policy.exploration_params) for policy_id, policy in self._pid2policy.items() + if isinstance(policy, RLPolicy) + } + + def record_transition( + self, agent: str, state: np.ndarray, action: dict, reward: float, next_state: np.ndarray, terminal: bool + ) -> None: + if isinstance(self._aid2policy[agent], RLPolicy): + self._aid2policy[agent].record(agent, state, action, reward, next_state, terminal) + + def improve(self, checkpoint_dir: str = None) -> None: + for policy_id, policy in self._pid2policy.items(): + if hasattr(policy, "improve"): + policy.improve() + if checkpoint_dir: + policy.save(path.join(checkpoint_dir, policy_id)) + + +class ParallelAgentWrapper(AbsAgentWrapper): + """Wrapper for multiple agents using multiple policies to expose simple single-agent interfaces. + + The policy instances are distributed across multiple processes to achieve parallel inference. 
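Note: the per-policy grouping used by ``choose_action`` above (and reused by the parallel wrapper below) can be sketched standalone; the agent and policy IDs here are hypothetical:

from collections import defaultdict
import numpy as np

aid2pid = {"port.0": "ac.0", "port.1": "ac.0", "port.2": "ac.1"}
states_by_agent_id = {agent: np.random.rand(4) for agent in aid2pid}

pid2states, pid2aids = defaultdict(list), defaultdict(list)
for agent_id, state in states_by_agent_id.items():
    policy_id = aid2pid[agent_id]
    pid2states[policy_id].append(state)
    pid2aids[policy_id].append(agent_id)

# Each policy gets one stacked batch, plus the agent order needed to route actions back.
batches = {pid: np.vstack(states) for pid, states in pid2states.items()}
print({pid: batch.shape for pid, batch in batches.items()})  # {'ac.0': (2, 4), 'ac.1': (1, 4)}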
+ """ + def __init__(self, get_policy_func_dict: Dict[str, Callable], aid2pid: Dict[str, str]) -> None: + super(ParallelAgentWrapper, self).__init__(aid2pid) + + self._inference_services = [] + self._conn = {} + + def _inference_service(id_, get_policy, conn) -> None: + policy = get_policy(id_) + while True: + msg = conn.recv() + if msg["type"] == "load": + if hasattr(policy, "load"): + policy.load(path.join(msg["dir"], id_)) + elif msg["type"] == "choose_action": + actions = policy(msg["states"]) + conn.send(actions) + elif msg["type"] == "set_state": + if hasattr(policy, "set_state"): + policy.set_state(msg["policy_state"]) + elif msg["type"] == "explore": + policy.explore() + elif msg["type"] == "exploit": + policy.exploit() + elif msg["type"] == "exploration_step": + if hasattr(policy, "exploration_step"): + policy.exploration_step() + elif msg["type"] == "rollout_info": + conn.send(policy.get_rollout_info() if hasattr(policy, "get_rollout_info") else None) + elif msg["type"] == "exploration_params": + conn.send(policy.exploration_params if hasattr(policy, "exploration_params") else None) + elif msg["type"] == "record": + if hasattr(policy, "record"): + policy.record( + msg["agent"], msg["state"], msg["action"], msg["reward"], msg["next_state"], msg["terminal"] + ) + elif msg["type"] == "update": + if hasattr(policy, "update"): + policy.update(msg["loss_info"]) + elif msg["type"] == "learn": + if hasattr(policy, "learn"): + policy.learn(msg["batch"]) + elif msg["type"] == "improve": + if hasattr(policy, "improve"): + policy.improve() + if msg["checkpoint_dir"]: + policy.save(path.join(msg["checkpoint_dir"], id_)) + + for policy_id in get_policy_func_dict: + conn1, conn2 = Pipe() + self._conn[policy_id] = conn1 + host = Process( + target=_inference_service, + args=(policy_id, get_policy_func_dict[policy_id], conn2) + ) + self._inference_services.append(host) + host.start() + + def load(self, dir: str) -> None: + for conn in self._conn.values(): + conn.send({"type": "load", "dir": dir}) + + def choose_action(self, states_by_agent_id: Dict[str, np.ndarray]) -> dict: + pid2states, pid2aids = defaultdict(list), defaultdict(list) + for agent_id, state in states_by_agent_id.items(): + policy_id = self._aid2pid[agent_id] + pid2states[policy_id].append(state) + pid2aids[policy_id].append(agent_id) + + # send state batch to inference processes for parallelized inference. 
+ for policy_id, conn in self._conn.items(): + if pid2states[policy_id]: + conn.send({"type": "choose_action", "states": np.vstack(pid2states[policy_id])}) + + action_by_agent = {} + for policy_id, conn in self._conn.items(): + if pid2states[policy_id]: + action_by_agent.update(zip(pid2aids[policy_id], conn.recv())) + + return action_by_agent + + def set_policy_states(self, states_by_policy_id: dict) -> None: + for policy_id, conn in self._conn.items(): + conn.send({"type": "set_state", "policy_state": states_by_policy_id[policy_id]}) + + def explore(self) -> None: + for conn in self._conn.values(): + conn.send({"type": "explore"}) + + def exploit(self) -> None: + for conn in self._conn.values(): + conn.send({"type": "exploit"}) + + def exploration_step(self) -> None: + for conn in self._conn.values(): + conn.send({"type": "exploration_step"}) + + def get_rollout_info(self) -> dict: + rollout_info = {} + for conn in self._conn.values(): + conn.send({"type": "rollout_info"}) + + for policy_id, conn in self._conn.items(): + info = conn.recv() + if info: + rollout_info[policy_id] = info + + return rollout_info + + def get_exploration_params(self) -> dict: + exploration_params = {} + for conn in self._conn.values(): + conn.send({"type": "exploration_params"}) + + for policy_id, conn in self._conn.items(): + params = conn.recv() + if params: + exploration_params[policy_id] = params + + return exploration_params + + def record_transition( + self, agent: str, state: np.ndarray, action: dict, reward: float, next_state: np.ndarray, terminal: bool + ) -> None: + self._conn[self._aid2pid[agent]].send({ + "type": "record", "agent": agent, "state": state, "action": action, "reward": reward, + "next_state": next_state, "terminal": terminal + }) + + def improve(self, checkpoint_dir: str = None) -> None: + for conn in self._conn.values(): + conn.send({"type": "improve", "checkpoint_dir": checkpoint_dir}) + + +class AbsEnvSampler(ABC): + """Simulation data collector and policy evaluator. + + Args: + get_env (Callable[[], Env]): Function to create an ``Env`` instance for collecting training data. The function + should take no parameters and return an environment wrapper instance. + get_policy_func_dict (dict): A dictionary mapping policy names to functions that create them. The policy + creation function should have policy name as the only parameter and return an ``AbsPolicy`` instance. + agent2policy (Dict[str, str]): A dictionary that maps agent IDs to policy IDs, i.e., specifies the policy used + by each agent. + get_test_env (Callable): Function to create an ``Env`` instance for testing policy performance. The function + should take no parameters and return an environment wrapper instance. If this is None, the training + environment wrapper will be used for evaluation in the worker processes. Defaults to None. + reward_eval_delay (int): Number of ticks required after a decision event to evaluate the reward + for the action taken for that event. Defaults to 0, which means rewards are evaluated immediately + after executing an action. + parallel_inference (bool): If True, the policies will be placed in separate processes so that inference can be + performed in parallel to speed up simulation. This is useful if some policies are big and take a long time + to generate actions. Defaults to False. 
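Note: a minimal, self-contained sketch of the Pipe-based request/reply pattern that the parallel inference wrapper above builds on; the service below is a stand-in (argmax over the state batch), not the real policy process:

import numpy as np
from multiprocessing import Pipe, Process

def inference_service(conn):
    # Stand-in policy process: answers "choose_action" requests until told to quit.
    while True:
        msg = conn.recv()
        if msg["type"] == "choose_action":
            conn.send(np.argmax(msg["states"], axis=1))
        elif msg["type"] == "quit":
            break

if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    proc = Process(target=inference_service, args=(child_conn,))
    proc.start()
    parent_conn.send({"type": "choose_action", "states": np.random.rand(3, 5)})
    print(parent_conn.recv())  # one action index per state in the batch
    parent_conn.send({"type": "quit"})
    proc.join()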
+ """ + def __init__( + self, + get_env: Callable[[], Env], + get_policy_func_dict: Dict[str, Callable], + agent2policy: Dict[str, str], + get_test_env: Callable[[], Env] = None, + reward_eval_delay: int = 0, + parallel_inference: bool = False + ) -> None: + self._learn_env = get_env() + self._test_env = get_test_env() if get_test_env else self._learn_env + self._env = None + + agent_wrapper_cls = ParallelAgentWrapper if parallel_inference else SimpleAgentWrapper + self._agent_wrapper: AbsAgentWrapper = agent_wrapper_cls(get_policy_func_dict, agent2policy) + + self._reward_eval_delay = reward_eval_delay + self._state = None + self._event = None + self._step_index = 0 + + self._transition_cache = defaultdict(deque) # for caching transitions whose rewards have yet to be evaluated + self._tracker = {} # User-defined tracking information is placed here. + + @property + def event(self) -> object: + return self._event + + @property + def agent_wrapper(self) -> AbsAgentWrapper: + return self._agent_wrapper + + @abstractmethod + def get_state(self, tick: int = None) -> dict: + """Compute the state for a given tick. + + Args: + tick (int): The tick for which to compute the environmental state. If computing the current state, + use tick=self.env.tick. + Returns: + A dictionary with (agent ID, state) as key-value pairs. + """ + raise NotImplementedError + + @abstractmethod + def get_env_actions(self, action) -> dict: + """Convert policy outputs to an action that can be executed by ``self.env.step()``.""" + raise NotImplementedError + + @abstractmethod + def get_reward(self, actions: list, tick: int) -> dict: + """Evaluate the reward for an action. + Args: + actions (list): Actions. + tick (int): Evaluate the reward for the actions that occured at the given tick. Each action in + ``actions`` must be an Action object defined for the environment in question. + + Returns: + A dictionary with (agent ID, reward) as key-value pairs. + """ + raise NotImplementedError + + def sample(self, policy_state_dict: dict = None, num_steps: int = -1, return_rollout_info: bool = True) -> dict: + self._env = self._learn_env + if not self._state: + # reset and get initial state + self._env.reset() + self._step_index = 0 + self._transition_cache.clear() + self._tracker.clear() + _, self._event, _ = self._env.step(None) + self._state = self.get_state() + + # set policy states + if policy_state_dict: + self._agent_wrapper.set_policy_states(policy_state_dict) + self._agent_wrapper.explore() + + starting_step_index = self._step_index + 1 + steps_to_go = float("inf") if num_steps == -1 else num_steps + while self._state and steps_to_go > 0: + action = self._agent_wrapper.choose_action(self._state) + env_actions = self.get_env_actions(action) + for agent, state in self._state.items(): + self._transition_cache[agent].append((state, action[agent], env_actions, self._env.tick)) + _, self._event, done = self._env.step(env_actions) + self._state = None if done else self.get_state() + self._step_index += 1 + steps_to_go -= 1 + + """ + If this is the final step, evaluate rewards for all remaining events except the last. + Otherwise, evaluate rewards only for events at least self.reward_eval_delay ticks ago. 
+ """ + for agent, cache in self._transition_cache.items(): + while cache and (not self._state or self._env.tick - cache[0][-1] >= self._reward_eval_delay): + state, action, env_actions, tick = cache.popleft() + reward = self.get_reward(env_actions, tick) + self.post_step(state, action, env_actions, reward, tick) + self._agent_wrapper.record_transition( + agent, state, action, reward[agent], cache[0][0] if cache else self._state, + not cache and not self._state + ) + + result = { + "step_range": (starting_step_index, self._step_index), + "tracker": self._tracker, + "end_of_episode": not self._state, + "exploration_params": self._agent_wrapper.get_exploration_params() + } + if return_rollout_info: + result["rollout_info"] = self._agent_wrapper.get_rollout_info() + + if not self._state: + self._agent_wrapper.exploration_step() + return result + + def test(self, policy_state_dict: dict = None) -> dict: + self._env = self._test_env + # set policy states + if policy_state_dict: + self._agent_wrapper.set_policy_states(policy_state_dict) + + # Set policies to exploitation mode + self._agent_wrapper.exploit() + + self._env.reset() + terminal = False + # get initial state + _, self._event, _ = self._env.step(None) + state = self.get_state() + while not terminal: + action = self._agent_wrapper.choose_action(state) + env_actions = self.get_env_actions(action) + _, self._event, terminal = self._env.step(env_actions) + if not terminal: + state = self.get_state() + + return self._tracker + + @abstractmethod + def post_step(self, state: np.ndarray, action, env_actions, reward, tick): # TODO: argu type + """ + Gather any information you wish to track during a roll-out episode and store it in the ``tracker`` attribute. + """ + pass + + def worker( + self, + group: str, + index: int, + num_extra_recv_attempts: int = 0, + recv_timeout: int = 100, + proxy_kwargs: dict = None, + log_dir: str = getcwd() + ): + """Roll-out worker process that can be launched on separate computation nodes. + + Args: + group (str): Group name for the roll-out cluster, which includes all roll-out workers and a roll-out manager + that manages them. + index (int): Worker index. The worker's ID in the cluster will be "ROLLOUT_WORKER.{worker_idx}". + This is used for bookkeeping by the roll-out manager. + num_extra_recv_attempts (int): Number of extra receive attempts after each received ``SAMPLE`` message. This + is used to catch the worker up to the latest episode in case it trails the main learning loop by at + least one full episode. Defaults to 0. + recv_timeout (int): Timeout for the extra receive attempts. Defaults to 100 (miliseconds). + proxy_kwargs: Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class + for details. Defaults to the empty dictionary. + log_dir (str): Directory to store logs in. Defaults to the current working directory. + """ + if proxy_kwargs is None: + proxy_kwargs = {} + + proxy = Proxy( + group, "rollout_worker", {"rollout_manager": 1}, component_name=f"ROLLOUT_WORKER.{index}", **proxy_kwargs + ) + logger = Logger(proxy.name, dump_folder=log_dir) + + """ + The event loop handles 3 types of messages from the roll-out manager: + 1) COLLECT, upon which the agent-environment simulation will be carried out for a specified number of steps + and the collected experiences will be sent back to the roll-out manager; + 2) EVAL, upon which the policies contained in the message payload will be evaluated for the entire + duration of the evaluation environment. 
+ 3) EXIT, upon which it will break out of the event loop and the process will terminate. + + """ + while True: + msg = proxy.receive_once() + if msg.tag == MsgTag.EXIT: + logger.info("Exiting...") + proxy.close() + break + + if msg.tag == MsgTag.SAMPLE: + latest = msg + for _ in range(num_extra_recv_attempts): + msg = proxy.receive_once(timeout=recv_timeout) + if msg.body[MsgKey.EPISODE] > latest.body[MsgKey.EPISODE]: + logger.info(f"Skipped roll-out message for ep {latest.body[MsgKey.EPISODE]}") + latest = msg + + ep = latest.body[MsgKey.EPISODE] + result = self.sample( + policy_state_dict=latest.body[MsgKey.POLICY_STATE], num_steps=latest.body[MsgKey.NUM_STEPS] + ) + logger.info( + get_rollout_finish_msg(ep, result["step_range"], exploration_params=result["exploration_params"]) + ) + return_info = { + MsgKey.EPISODE: ep, + MsgKey.SEGMENT: latest.body[MsgKey.SEGMENT], + MsgKey.ROLLOUT_INFO: result["rollout_info"], + MsgKey.STEP_RANGE: result["step_range"], + MsgKey.TRACKER: result["tracker"], + MsgKey.END_OF_EPISODE: result["end_of_episode"] + } + proxy.reply(latest, tag=MsgTag.SAMPLE_DONE, body=return_info) + elif msg.tag == MsgTag.TEST: + tracker = self.test(msg.body[MsgKey.POLICY_STATE]) + return_info = {MsgKey.TRACKER: tracker, MsgKey.EPISODE: msg.body[MsgKey.EPISODE]} + logger.info("Testing complete") + proxy.reply(msg, tag=MsgTag.TEST_DONE, body=return_info) + + def actor( + self, + group: str, + index: int, + num_episodes: int, + num_steps: int = -1, + proxy_kwargs: dict = None, + log_dir: str = getcwd() + ): + """Controller for single-threaded learning workflows. + + Args: + group (str): Group name for the cluster that includes the server and all actors. + index (int): Integer actor index. The actor's ID in the cluster will be "ACTOR.{actor_idx}". + num_episodes (int): Number of training episodes. Each training episode may contain one or more + collect-update cycles, depending on how the implementation of the roll-out manager. + num_steps (int): Number of environment steps to roll out in each call to ``collect``. Defaults to -1, in + which case the roll-out will be executed until the end of the environment. + proxy_kwargs: Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class + for details. Defaults to the empty dictionary. + log_dir (str): Directory to store logs in. A ``Logger`` with tag "LOCAL_ROLLOUT_MANAGER" will be created at + init time and this directory will be used to save the log files generated by it. Defaults to the current + working directory. 
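Note: for a single-process workflow (the "single" mode in config.yml), the sampler can be driven directly through ``sample``; a hedged usage sketch, assuming the cim_v2 example directory from earlier in this diff is on the path and MARO is installed:

from env_sampler import get_env_sampler

env_sampler = get_env_sampler()
result = env_sampler.sample(num_steps=-1)               # collect one full episode
print(result["step_range"], result["end_of_episode"])
rollout_info = result["rollout_info"]                   # per-policy loss info or batches for training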
+ """ + if proxy_kwargs is None: + proxy_kwargs = {} + + if num_steps == 0 or num_steps < -1: + raise ValueError("num_steps must be a positive integer or -1") + + name = f"ACTOR.{index}" + logger = Logger(name, dump_folder=log_dir) + peers = {"policy_server": 1} + proxy = Proxy(group, "actor", peers, component_name=name, **proxy_kwargs) + server_address = proxy.peers["policy_server"][0] + + # get initial policy states from the policy manager + msg = SessionMessage(MsgTag.GET_INITIAL_POLICY_STATE, proxy.name, server_address) + reply = proxy.send(msg)[0] + policy_state_dict, policy_version = reply.body[MsgKey.POLICY_STATE], reply.body[MsgKey.VERSION] + + # main loop + for ep in range(1, num_episodes + 1): + while True: + result = self.sample(policy_state_dict=policy_state_dict, num_steps=num_steps) + logger.info( + get_rollout_finish_msg(ep, result["step_range"], exploration_params=result["exploration_params"]) + ) + # Send roll-out info to policy server for learning + reply = proxy.send( + SessionMessage( + MsgTag.SAMPLE_DONE, proxy.name, server_address, + body={MsgKey.ROLLOUT_INFO: result["rollout_info"], MsgKey.VERSION: policy_version} + ) + )[0] + policy_state_dict, policy_version = reply.body[MsgKey.POLICY_STATE], reply.body[MsgKey.VERSION] + if result["end_of_episode"]: + break + + # tell the policy server I'm all done. + proxy.isend(SessionMessage(MsgTag.DONE, proxy.name, server_address, session_type=SessionType.NOTIFICATION)) + proxy.close() diff --git a/maro/rl/learning/rollout_manager.py b/maro/rl/learning/rollout_manager.py index f70f53ddc..6ae99e578 100644 --- a/maro/rl/learning/rollout_manager.py +++ b/maro/rl/learning/rollout_manager.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC, abstractmethod +from abc import abstractmethod from collections import defaultdict from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection from os import getcwd, getpid from random import choices from typing import Callable, Dict, List, Tuple @@ -18,15 +19,15 @@ from .helpers import get_rollout_finish_msg -def concat_batches(batch_list: List[dict]): +def concat_batches(batch_list: List[dict]) -> dict: return {key: np.concatenate([batch[key] for batch in batch_list]) for key in batch_list[0]} -class AbsRolloutManager(ABC): +class AbsRolloutManager(object): """Controller for simulation data collection.""" - def __init__(self): - super().__init__() - self.end_of_episode = False + def __init__(self) -> None: + super(AbsRolloutManager, self).__init__() + self._end_of_episode = False @abstractmethod def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, List[Dict]]: @@ -45,8 +46,12 @@ def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, """ raise NotImplementedError + @property + def end_of_episode(self) -> bool: + return self._end_of_episode + @abstractmethod - def evaluate(self, ep: int, policy_state_dict: dict): + def evaluate(self, ep: int, policy_state_dict: dict) -> list: """Evaluate policy performance. Args: @@ -59,10 +64,11 @@ def evaluate(self, ep: int, policy_state_dict: dict): """ raise NotImplementedError - def reset(self): - self.end_of_episode = False + def reset(self) -> None: + self._end_of_episode = False - def exit(self): + @abstractmethod + def exit(self) -> None: pass @@ -75,7 +81,7 @@ class MultiProcessRolloutManager(AbsRolloutManager): num_rollouts (int): Number of processes to spawn for parallel roll-out. 
num_steps (int): Number of environment steps to roll out in each call to ``collect``. Defaults to -1, in which case the roll-out will be executed until the end of the environment. - num_eval_rollout (int): Number of roll-out processes to use for evaluation. Defaults to 1. + num_eval_rollouts (int): Number of roll-out processes to use for evaluation. Defaults to 1. log_dir (str): Directory to store logs in. A ``Logger`` with tag "ROLLOUT_MANAGER" will be created at init time and this directory will be used to save the log files generated by it. Defaults to the current working directory. @@ -87,7 +93,9 @@ def __init__( num_steps: int = -1, num_eval_rollouts: int = 1, log_dir: str = getcwd() - ): + ) -> None: + super(MultiProcessRolloutManager, self).__init__() + if num_steps == 0 or num_steps < -1: raise ValueError("num_steps must be a positive integer or -1.") @@ -97,7 +105,6 @@ def __init__( if num_eval_rollouts > num_rollouts: raise ValueError("'num_eval_rollouts' can not be greater than 'num_rollouts'.") - super().__init__() self._logger = Logger("ROLLOUT_MANAGER", dump_folder=log_dir) self._num_steps = num_steps if num_steps > 0 else float("inf") self._num_rollouts = num_rollouts @@ -105,7 +112,7 @@ def __init__( self._worker_processes = [] self._manager_ends = [] - def _rollout_worker(index, conn, get_env_sampler): + def _rollout_worker(index: int, conn: Connection, get_env_sampler: Callable) -> None: set_seeds(index) env_sampler = get_env_sampler() logger = Logger("ROLLOUT_WORKER", dump_folder=log_dir) @@ -134,19 +141,6 @@ def _rollout_worker(index, conn, get_env_sampler): worker.start() def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, List[Dict]]: - """Collect simulation data, i.e., experiences for training. - - Args: - ep (int): Current episode. - segment (int): Current segment. - policy_state_dict (dict): Policy states to use for collecting training info. - - Returns: - A 2-tuple consisting of a dictionary of roll-out information grouped by policy ID and a list of dictionaries - containing step-level information collected by the user-defined ``post_step`` callback in ``AbsEnvSampler``. - An RL policy's roll-out information must be either loss information or a data batch that can be passed to - the policy's ``update`` or ``learn``, respectively. - """ self._logger.info(f"Collecting simulation data (episode {ep}, segment {segment})") info_by_policy, trackers = defaultdict(list), [] @@ -165,26 +159,17 @@ def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, for policy_id, info in result["rollout_info"].items(): info_by_policy[policy_id].append(info) trackers.append(result["tracker"]) - self.end_of_episode = result["end_of_episode"] + self._end_of_episode = result["end_of_episode"] # concat batches from different roll-out workers - for policy_id, info_list in info_by_policy.items(): + new_info_by_policy = {k: v for k, v in info_by_policy.items()} + for policy_id, info_list in new_info_by_policy.items(): if "loss" not in info_list[0]: - info_by_policy[policy_id] = concat_batches(info_list) + new_info_by_policy[policy_id] = concat_batches(info_list) - return info_by_policy, trackers + return new_info_by_policy, trackers - def evaluate(self, ep: int, policy_state_dict: dict): - """Evaluate policy performance. - - Args: - ep (int): Current training episode. - policy_state_dict (dict): Policy states to use for evaluation. 
- - Returns: - A list of dictionaries containing step-level information collected by the user-defined ``post_step`` - callback in ``AbsEnvSampler`` for evaluation purposes. - """ + def evaluate(self, ep: int, policy_state_dict: dict) -> list: trackers = [] eval_worker_conns = choices(self._manager_ends, k=self._num_eval_rollouts) for conn in eval_worker_conns: @@ -195,7 +180,7 @@ def evaluate(self, ep: int, policy_state_dict: dict): return trackers - def exit(self): + def exit(self) -> None: """Tell the worker processes to exit.""" for conn in self._manager_ends: conn.send({"type": "quit"}) @@ -231,15 +216,21 @@ def __init__( min_finished_workers: int = None, max_extra_recv_tries: int = 0, extra_recv_timeout: int = 100, - max_lag: Dict[str, int] = defaultdict(int), + max_lag: Dict[str, int] = None, num_eval_workers: int = 1, - proxy_kwargs: dict = {}, + proxy_kwargs: dict = None, log_dir: str = getcwd() - ): + ) -> None: + super(DistributedRolloutManager, self).__init__() + + if max_lag is None: + max_lag = defaultdict(int) + if proxy_kwargs is None: + proxy_kwargs = {} + if num_eval_workers > num_workers: raise ValueError("num_eval_workers cannot exceed the number of available workers") - super().__init__() self._num_workers = num_workers peers = {"rollout_worker": num_workers} self._proxy = Proxy(group, "rollout_manager", peers, component_name="ROLLOUT_MANAGER", **proxy_kwargs) @@ -264,19 +255,6 @@ def __init__( self._num_eval_workers = num_eval_workers def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, List[Dict]]: - """Collect simulation data, i.e., experiences for training. - - Args: - ep (int): Current episode. - segment (int): Current segment. - policy_state_dict (dict): Policy states to use for collecting training info. - - Returns: - A 2-tuple consisting of a dictionary of roll-out information grouped by policy ID and a list of dictionaries - containing step-level information collected by the user-defined ``post_step`` callback in ``AbsEnvSampler``. - An RL policy's roll-out information must be either loss information or a data batch that can be passed to - the policy's ``update`` or ``learn``, respectively. - """ msg_body = { MsgKey.EPISODE: ep, MsgKey.SEGMENT: segment, @@ -315,13 +293,14 @@ def collect(self, ep: int, segment: int, policy_state_dict: dict) -> Tuple[Dict, break # concat batches from different roll-out workers - for policy_id, info_list in info_by_policy.items(): + new_info_by_policy = {k: v for k, v in info_by_policy.items()} + for policy_id, info_list in new_info_by_policy.items(): if "loss" not in info_list[0]: - info_by_policy[policy_id] = concat_batches(info_list) + new_info_by_policy[policy_id] = concat_batches(info_list) - return info_by_policy, trackers + return new_info_by_policy, trackers - def _handle_worker_result(self, msg, ep, segment): + def _handle_worker_result(self, msg, ep, segment) -> tuple: if msg.tag != MsgTag.SAMPLE_DONE: self._logger.info( f"Ignored a message of type {msg.tag} (expected message type {MsgTag.SAMPLE_DONE})" @@ -330,22 +309,12 @@ def _handle_worker_result(self, msg, ep, segment): # The message is what we expect if msg.body[MsgKey.EPISODE] == ep and msg.body[MsgKey.SEGMENT] == segment: - self.end_of_episode = msg.body[MsgKey.END_OF_EPISODE] + self._end_of_episode = msg.body[MsgKey.END_OF_EPISODE] return msg.body[MsgKey.ROLLOUT_INFO], msg.body[MsgKey.TRACKER] return None, None - def evaluate(self, ep: int, policy_state_dict: dict): - """Evaluate policy performance. 
- - Args: - ep (int): Current training episode. - policy_state_dict (dict): Policy states to use for evaluation. - - Returns: - A list of dictionaries containing step-level information collected by the user-defined ``post_step`` - callback in ``AbsEnvSampler`` for evaluation purposes. - """ + def evaluate(self, ep: int, policy_state_dict: dict) -> list: msg_body = {MsgKey.EPISODE: ep, MsgKey.POLICY_STATE: policy_state_dict} workers = choices(self._workers, k=self._num_eval_workers) @@ -371,7 +340,7 @@ def evaluate(self, ep: int, policy_state_dict: dict): return trackers - def exit(self): + def exit(self) -> None: """Tell the remote workers to exit.""" self._proxy.ibroadcast("rollout_worker", MsgTag.EXIT, SessionType.NOTIFICATION) self._proxy.close() diff --git a/maro/rl/modeling_v2/__init__.py b/maro/rl/modeling_v2/__init__.py new file mode 100644 index 000000000..d4d73ab1a --- /dev/null +++ b/maro/rl/modeling_v2/__init__.py @@ -0,0 +1,11 @@ +from .ac_network import DiscreteActorCriticNet, DiscreteVActorCriticNet +from .base_model import AbsCoreModel, PolicyNetwork +from .pg_network import DiscretePolicyGradientNetwork +from .q_network import DiscreteQNetwork, QNetwork + +__all__ = [ + "DiscreteActorCriticNet", "DiscreteVActorCriticNet", + "AbsCoreModel", "PolicyNetwork", + "DiscretePolicyGradientNetwork", + "DiscreteQNetwork", "QNetwork" +] diff --git a/maro/rl/modeling_v2/ac_network.py b/maro/rl/modeling_v2/ac_network.py new file mode 100644 index 000000000..a20aff426 --- /dev/null +++ b/maro/rl/modeling_v2/ac_network.py @@ -0,0 +1,63 @@ +from abc import ABCMeta +from typing import Optional, Tuple + +import torch +from torch.distributions import Categorical + +from .base_model import DiscreteProbPolicyNetworkMixin, PolicyNetwork +from .critic_model import CriticMixin, VCriticMixin + + +class DiscreteActorCriticNet(CriticMixin, DiscreteProbPolicyNetworkMixin, PolicyNetwork, metaclass=ABCMeta): + """ + Model framework for the actor-critic architecture. + + All concrete classes that inherit `DiscreteActorCriticNet` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `DiscreteProbPolicyNetworkMixin`: + - _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + """ + + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscreteActorCriticNet, self).__init__(state_dim=state_dim, action_dim=1) + self._action_num = action_num + + def _get_action_num(self) -> int: + return self._action_num + + def _get_actions_and_logps_exploring_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + action_probs = Categorical(self.get_probs(states)) + actions = action_probs.sample() + logps = action_probs.log_prob(actions) + return actions, logps + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + return self.get_actions_and_logps(states, exploring)[0].unsqueeze(1) + + def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + return self._policy_net_shape_check(states=states, actions=actions) + + +class DiscreteVActorCriticNet(VCriticMixin, DiscreteActorCriticNet, metaclass=ABCMeta): + """ + Model framework for the actor-critic architecture for finite and discrete action spaces. 
+ + All concrete classes that inherit `DiscreteVActorCriticNet` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `DiscreteProbPolicyNetworkMixin`: + - _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + - Declared in `VCriticMixin`: + - _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscreteVActorCriticNet, self).__init__(state_dim=state_dim, action_num=action_num) diff --git a/maro/rl/modeling_v2/base_model.py b/maro/rl/modeling_v2/base_model.py new file mode 100644 index 000000000..6b07a0f32 --- /dev/null +++ b/maro/rl/modeling_v2/base_model.py @@ -0,0 +1,320 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod +from typing import Optional, Tuple + +import torch + +from maro.rl.utils import match_shape + + +class AbsCoreModel(torch.nn.Module): + """ + The ancestor of all Torch models in MARO. + + All concrete classes that inherit `AbsCoreModel` should implement all abstract methods + declared in `AbsCoreModel`, includes: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + """ + + def __init__(self) -> None: + super(AbsCoreModel, self).__init__() + + @abstractmethod + def step(self, loss: torch.tensor) -> None: + """Use a computed loss to back-propagate gradients and apply them to the underlying parameters. + + Args: + loss: Result of a computation graph that involves the underlying parameters. + """ + pass + + @abstractmethod + def get_gradients(self, loss: torch.tensor) -> torch.tensor: + """Get gradients from a computed loss. + + There are two possible scenarios where you need to implement this interface: 1) if you are doing distributed + learning and want each roll-out instance to collect gradients that can be directly applied to policy parameters + on the learning side (abstracted through ``AbsPolicyManager``); 2) if you are computing loss in data-parallel + fashion, i.e., by splitting a data batch to several smaller batches and sending them to a set of remote workers + for parallelized gradient computation. In this case, this method will be used by the remote workers. + """ + pass + + @abstractmethod + def apply_gradients(self, grad: dict) -> None: + """Apply gradients to the model parameters. + + This needs to be implemented together with ``get_gradients``. + """ + pass + + @abstractmethod + def get_state(self) -> object: + """Return the current model state. + + Ths model state usually involves the "state_dict" of the module as well as those of the embedded optimizers. + """ + pass + + @abstractmethod + def set_state(self, state: object) -> None: + """Set model state. + + Args: + state: Model state to be applied to the instance. Ths model state is either the result of a previous call + to ``get_state`` or something loaded from disk and involves the "state_dict" of the module as well as those + of the embedded optimizers. + """ + pass + + def soft_update(self, other_model: torch.nn.Module, tau: float) -> None: + """Soft-update model parameters using another model. 
+ + Update formulae: param = (1 - tau) * param + tau * other_param. + + Args: + other_model: The model to update the current model with. + tau (float): Soft-update coefficient. + """ + for params, other_params in zip(self.parameters(), other_model.parameters()): + params.data = (1 - tau) * params.data + tau * other_params.data + + +class SimpleNetwork(AbsCoreModel): + """ + Simple neural network that has one input and one output. + + `SimpleNetwork` does not contain any semantics and therefore can be used for any purpose. However, we recommend + users to use `PolicyNetwork` if the network is used for generating actions according to states. `PolicyNetwork` + has better supports for these functionalities. + + All concrete classes that inherit `SimpleNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `SimpleNetwork`: + - _forward_impl(self, x: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, input_dim: int, output_dim: int) -> None: + super(SimpleNetwork, self).__init__() + self._input_dim = input_dim + self._output_dim = output_dim + + @property + def input_dim(self) -> int: + """Input dimension of the network.""" + return self._input_dim + + @property + def output_dim(self) -> int: + """Output dimension of the network.""" + return self._input_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._forward_impl(x) + + @abstractmethod + def _forward_impl(self, x: torch.Tensor) -> torch.Tensor: + """ + The implementation that contains the actual logic of the network. Users should implement their own logics + in this method. + """ + pass + + +class ShapeCheckMixin: + """ + Mixin that contains the `_policy_net_shape_check` method, which is used for checking whether the states and actions + have valid shapes. Usually, it should contains three parts: + 1. Check of states' shape. + 2. Check of actions' shape. + 3. Check whether states and actions have identical batch sizes. + + `actions` is optional. If it is None, it means we do not need to check action related issues. + """ + @abstractmethod + def _policy_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + pass + + +class PolicyNetwork(ShapeCheckMixin, AbsCoreModel): + """ + Neural networks for policies. 
+ + All concrete classes that inherit `PolicyNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `PolicyNetwork`: + - _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_dim: int) -> None: + super(PolicyNetwork, self).__init__() + self._state_dim = state_dim + self._action_dim = action_dim + + @property + def state_dim(self) -> int: + """State dimension.""" + return self._state_dim + + @property + def action_dim(self) -> int: + """Action dimension""" + return self._action_dim + + def _policy_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + return all([ + states.shape[0] > 0 and match_shape(states, (None, self.state_dim)), + actions is None or (actions.shape[0] > 0 and match_shape(actions, (None, self.action_dim))), + actions is None or states.shape[0] == actions.shape[0] + ]) + + def get_actions(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + """ + Get actions according to the given states. The actual logics should be implemented in `_get_actions_impl`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + exploring (bool): Get the actions under exploring mode (True) or exploiting mode (False). + + Returns: + actions (torch.Tensor) with shape [batch_size, action_dim]. + """ + assert self._policy_net_shape_check(states=states, actions=None) + ret = self._get_actions_impl(states, exploring) + assert match_shape(ret, (states.shape[0], self._action_dim)) + return ret + + @abstractmethod + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + """ + Implementation of `get_actions`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + exploring (bool): Get the actions under exploring mode (True) or exploiting mode (False). + + Returns: + actions (torch.Tensor) with shape [batch_size, action_dim]. + """ + pass + + +class DiscretePolicyNetworkMixin: + """ + Mixin for discrete policy networks. All policy networks that generate discrete actions should extend this mixin + and implemented all methods inherited from this mixin. + """ + @property + def action_num(self) -> int: + """ + Returns the number of actions. + """ + return self._get_action_num() + + @abstractmethod + def _get_action_num(self) -> int: + """ + Implementation of `action_num`. + """ + pass + + +class DiscreteProbPolicyNetworkMixin(DiscretePolicyNetworkMixin, ShapeCheckMixin): + """ + Mixin for discrete policy networks that have the concept of 'probability'. Policy networks that extend this mixin + should first calculate the probability for each potential action, and then choose the action according to the + probabilities. + + Notice: any concrete class that inherits `DiscreteProbPolicyNetworkMixin` should also implement + `_get_action_num(self) -> int:` defined in `DiscretePolicyNetworkMixin`. + """ + def get_probs(self, states: torch.Tensor) -> torch.Tensor: + """ + Get probabilities of all potential actions. The actual logics should be implemented in `_get_probs_impl`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. 
+ + Returns: + probability matrix: [batch_size, action_num] + """ + self._policy_net_shape_check(states=states, actions=None) + ret = self._get_probs_impl(states) + assert match_shape(ret, (states.shape[0], self.action_num)) + return ret + + @abstractmethod + def _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + """ + Implementation of `get_probs`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + + Returns: + probability matrix: [batch_size, action_num] + """ + pass + + def get_logps(self, states: torch.Tensor) -> torch.Tensor: + """ + Get log-probabilities of all possible actions. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + + Returns: + Log-probability matrix: [batch_size, action_num] + """ + return torch.log(self.get_probs(states)) + + def get_actions_and_logps(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get actions and corresponding log-probabilities. + + If under exploring mode (exploring = True), the actions shall be taken follow the logic implemented in + `_get_actions_and_logps_exploration_impl`. If under exploiting mode (exploring = False), the actions + will be taken through a greedy strategy (choose the action with the highest probability). + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + exploring (bool): `True` means under exploring mode. `False` means under exploiting mode. + + Returns: + Actions and log-P values, both with shape [batch_size]. + """ + if exploring: + actions, logps = self._get_actions_and_logps_exploring_impl(states) + else: + action_prob = self.get_logps(states) # [batch_size, num_actions] + logps, actions = action_prob.max(dim=1) + assert match_shape(actions, (states.shape[0],)) + assert match_shape(logps, (states.shape[0],)) + return actions, logps + + @abstractmethod + def _get_actions_and_logps_exploring_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get actions and corresponding log-probabilities under exploring mode. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + + Returns: + Actions and log-P values, both with shape [batch_size]. + """ + pass diff --git a/maro/rl/modeling_v2/critic_model.py b/maro/rl/modeling_v2/critic_model.py new file mode 100644 index 000000000..46520ae98 --- /dev/null +++ b/maro/rl/modeling_v2/critic_model.py @@ -0,0 +1,291 @@ +from abc import ABCMeta, abstractmethod +from typing import List, Optional + +import torch + +from maro.rl.modeling_v2.base_model import AbsCoreModel +from maro.rl.utils import match_shape + + +class CriticMixin: + """ + Mixin for all networks that used as critic models. + """ + @abstractmethod + def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + """ + Checks whether the states and actions + have valid shapes. Usually, it should contains three parts: + 1. Check of states' shape. + 2. Check of actions' shape. + 3. Check whether states and actions have identical batch sizes. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + actions (Optional[torch.Tensor]): Actions with shape [batch_size, action_dim] or None. If it is None, it + means we don't need to check action related issues. + + Returns: + Whether the states and actions have valid shapes + """ + pass + + +class VCriticMixin(CriticMixin): + """ + Mixin for all networks that used as V-value based critic models. 
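As a small numeric illustration of this contract (values chosen for illustration only): `get_probs` returns a probability matrix, `get_logps` is its element-wise log, and the exploiting branch of `get_actions_and_logps` simply takes the row-wise maximum:

import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.1, 0.8]])        # get_probs output: [batch_size=2, action_num=3]
logps = torch.log(probs)                        # what get_logps returns
greedy_logps, greedy_actions = logps.max(dim=1)
# greedy_actions -> tensor([0, 2]); greedy_logps -> tensor([-0.3567, -0.2231]) (approximately)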
+ + All concrete classes that inherit `VCriticMixin` should implement the following abstract methods: + - Declared in `CriticMixin`: + - _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + - Declared in `VCriticMixin`: + - _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + """ + def v_critic(self, states: torch.Tensor) -> torch.Tensor: + """ + Get V-values of the given states. The actual logics should be implemented in `_get_v_critic`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + + Returns: + V-values for critic with shape [batch_size] + """ + assert self._critic_net_shape_check(states=states, actions=None) + ret = self._get_v_critic(states) + assert match_shape(ret, (states.shape[0],)) + return ret + + @abstractmethod + def _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of v_critic.""" + pass + + +class QCriticMixin(CriticMixin): + """ + Mixin for all networks that used as Q-value based critic models. + + All concrete classes that inherit `QCriticMixin` should implement the following abstract methods: + - Declared in `CriticMixin`: + - _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + - Declared in `QCriticMixin`: + - _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + def q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + Get Q-values according to the given states and actions. + The actual logics should be implemented in `q_critic`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + actions (torch.Tensor): Actions with shape [batch_size, action_dim]. + + Returns: + Q-values for critic with shape [batch_size] + """ + assert self._critic_net_shape_check(states=states, actions=actions) + ret = self._get_q_critic(states, actions) + assert match_shape(ret, (states.shape[0],)) + return ret + + @abstractmethod + def _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """Implementation of q_critic.""" + pass + + +class CriticNetwork(AbsCoreModel, metaclass=ABCMeta): + """ + Neural networks for critic models. + + All concrete classes that inherit `CriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + """ + def __init__(self, state_dim: int) -> None: + super(CriticNetwork, self).__init__() + self._state_dim = state_dim + + @property + def state_dim(self) -> int: + return self._state_dim + + def _is_valid_state_shape(self, states: torch.Tensor) -> bool: + return states.shape[0] > 0 and match_shape(states, (None, self.state_dim)) + + +class VCriticNetwork(VCriticMixin, CriticNetwork, metaclass=ABCMeta): + """ + Neural networks for V-value based critic models. 
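As a concrete illustration, a `VCriticNetwork` subclass only has to map a [batch_size, state_dim] input to a [batch_size] output via `_get_v_critic`. A hypothetical sketch (layer sizes are arbitrary; the `AbsCoreModel` training methods are omitted, so the class stays abstract until they are added):

import torch

class MyVCritic(VCriticNetwork):  # hypothetical sketch
    def __init__(self, state_dim: int) -> None:
        super(MyVCritic, self).__init__(state_dim=state_dim)
        self._fc = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
        )

    def _get_v_critic(self, states: torch.Tensor) -> torch.Tensor:
        # Drop the trailing singleton dimension so the result matches the asserted shape [batch_size].
        return self._fc(states).squeeze(-1)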
+ + All concrete classes that inherit `VCriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `VCriticMixin`: + - _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int) -> None: + super(VCriticNetwork, self).__init__(state_dim=state_dim) + + def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + return self._is_valid_state_shape(states) + + +class QCriticNetwork(QCriticMixin, CriticNetwork, metaclass=ABCMeta): + """ + Neural networks for Q-value based critic models. + + All concrete classes that inherit `QCriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `QCriticMixin`: + - _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_dim: int) -> None: + super(QCriticNetwork, self).__init__(state_dim=state_dim) + self._action_dim = action_dim + + @property + def action_dim(self) -> int: + return self._action_dim + + def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + return all([ + self._is_valid_state_shape(states), + self._is_valid_action_shape(actions), + states.shape[0] == actions.shape[0], + ]) + + def _is_valid_action_shape(self, actions: torch.Tensor) -> bool: + return actions.shape[0] > 0 and match_shape(actions, (None, self.action_dim)) + + +class DiscreteQCriticNetwork(QCriticNetwork): + """ + Neural networks for Q-value based critic models that take discrete actions as inputs. + + All concrete classes that inherit `DiscreteQCriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `QCriticMixin`: + - _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + - Declared in `DiscreteQCriticNetwork`: + - _get_q_critic_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscreteQCriticNetwork, self).__init__(state_dim=state_dim, action_dim=1) + self._action_num = action_num + + def _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + q_matrix = self.q_critic_for_all_actions(states) # [batch_size, action_num] + actions = actions.unsqueeze(dim=1) + return q_matrix.gather(dim=1, index=actions).reshape(-1) + + @property + def action_num(self) -> int: + return self._action_num + + def q_critic_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """ + Generates the matrix that contains the Q-values for all potential actions. + The actual logics should be implemented in `_get_q_critic_for_all_actions`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. 
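The reduction from a per-action Q-matrix to per-sample Q-values (used by `_get_q_critic` above, and again by `DiscreteQNetwork` later in this diff) is a plain `gather` along the action dimension; a tiny example with made-up numbers:

import torch

q_matrix = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])    # [batch_size=2, action_num=3]
actions = torch.tensor([[2], [0]])             # one discrete action index per sample
q_values = q_matrix.gather(dim=1, index=actions).reshape(-1)
# q_values -> tensor([3., 4.]): the Q-value of the chosen action for each sample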
+ + Returns: + q critics for all actions with shape [batch_size, action_num] + """ + assert self._is_valid_state_shape(states) + ret = self._get_q_critic_for_all_actions(states) + assert match_shape(ret, (states.shape[0], self.action_num)) + return ret + + @abstractmethod + def _get_q_critic_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of `q_critic_for_all_actions`""" + pass + + +class MultiQCriticNetwork(QCriticMixin, CriticNetwork, metaclass=ABCMeta): + """ + Neural networks for Q-value based critic models that takes multiple actions as inputs. + This is used for multi-agent RL scenarios. + + All concrete classes that inherit `MultiQCriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `QCriticMixin`: + - _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_dim: int, agent_num: int) -> None: + super(MultiQCriticNetwork, self).__init__(state_dim=state_dim) + self._action_dim = action_dim + self._agent_num = agent_num + + @property + def action_dim(self) -> int: + return self._action_dim + + @property + def agent_num(self) -> int: + return self._agent_num + + def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool: + return all([ + self._is_valid_state_shape(states), + actions is None or self._is_valid_action_shape(actions), + actions is None or states.shape[0] == actions.shape[0] + ]) + + def _is_valid_action_shape(self, actions: torch.Tensor) -> bool: + return match_shape(actions, (None, self.agent_num, self.action_dim)) + + +class MultiDiscreteQCriticNetwork(MultiQCriticNetwork, metaclass=ABCMeta): + """ + Neural networks for Q-value based critic models that take multiple discrete actions as inputs. + + All concrete classes that inherit `MultiDiscreteQCriticNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `QCriticMixin`: + - _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_nums: List[int]) -> None: + super(MultiDiscreteQCriticNetwork, self).__init__(state_dim=state_dim, action_dim=1, agent_num=len(action_nums)) + self._action_nums = action_nums + self._agent_num = len(action_nums) + + @property + def action_nums(self) -> List[int]: + return self._action_nums + + @property + def agent_num(self) -> int: + return self._agent_num diff --git a/maro/rl/modeling_v2/pg_network.py b/maro/rl/modeling_v2/pg_network.py new file mode 100644 index 000000000..890b2c48e --- /dev/null +++ b/maro/rl/modeling_v2/pg_network.py @@ -0,0 +1,39 @@ +from abc import ABCMeta +from typing import Tuple + +import torch +from torch.distributions import Categorical + +from .base_model import DiscreteProbPolicyNetworkMixin, PolicyNetwork + + +class DiscretePolicyGradientNetwork(DiscreteProbPolicyNetworkMixin, PolicyNetwork, metaclass=ABCMeta): + """ + Model framework for the policy gradient networks. 
+ + All concrete classes that inherit `DiscretePolicyGradientNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `DiscreteProbPolicyNetworkMixin`: + - _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + """ + + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscretePolicyGradientNetwork, self).__init__(state_dim, 1) + self._action_num = action_num + + def _get_actions_and_logps_exploring_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + action_probs = Categorical(self.get_probs(states)) + actions = action_probs.sample() + logps = action_probs.log_prob(actions) + return actions, logps + + def _get_action_num(self) -> int: + return self._action_num + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + return self.get_actions_and_logps(states, exploring)[0].unsqueeze(1) diff --git a/maro/rl/modeling_v2/q_network.py b/maro/rl/modeling_v2/q_network.py new file mode 100644 index 000000000..f40d6eeb1 --- /dev/null +++ b/maro/rl/modeling_v2/q_network.py @@ -0,0 +1,103 @@ +from abc import abstractmethod + +import torch + +from maro.rl.utils import match_shape + +from .base_model import DiscretePolicyNetworkMixin, PolicyNetwork + + +class QNetwork(PolicyNetwork): + """ + Q-network for value-based policies. The action could be either continuous or discrete. + + All concrete classes that inherit `QNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `QNetwork`: + - _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + + def __init__(self, state_dim: int, action_dim: int) -> None: + super(QNetwork, self).__init__(state_dim=state_dim, action_dim=action_dim) + + def q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """ + Return the Q-values according to states and actions. + The actual logics should be implemented in `_get_q_values`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + actions (torch.Tensor) : Actions with shape [batch_size, action_dim]. + + Returns: + Q-values with shape [batch_size]. + """ + assert self._policy_net_shape_check(states=states, actions=actions) + ret = self._get_q_values(states, actions) + assert match_shape(ret, (states.shape[0],)) + return ret + + @abstractmethod + def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """Implementation of `q_values`.""" + pass + + +class DiscreteQNetwork(DiscretePolicyNetworkMixin, QNetwork): + """ + Q-network for discrete value-based policies. 
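A concrete `DiscretePolicyGradientNetwork` therefore only has to supply `_get_probs_impl` plus the `AbsCoreModel` training methods. A hypothetical sketch (the optimizer and state plumbing follow the same pattern as the `SimpleNetwork` sketch earlier and are omitted here, so this class remains abstract until those methods are added):

import torch

class MyPolicyNet(DiscretePolicyGradientNetwork):  # hypothetical sketch
    def __init__(self, state_dim: int, action_num: int) -> None:
        super(MyPolicyNet, self).__init__(state_dim=state_dim, action_num=action_num)
        self._fc = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, action_num)
        )

    def _get_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
        # Softmax over the logits yields the [batch_size, action_num] matrix that get_probs asserts.
        return torch.softmax(self._fc(states), dim=1)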
+ + All concrete classes that inherit `DiscreteQNetwork` should implement the following abstract methods: + - Declared in `AbsCoreModel`: + - step(self, loss: torch.tensor) -> None: + - get_gradients(self, loss: torch.tensor) -> torch.tensor: + - apply_gradients(self, grad: dict) -> None: + - get_state(self) -> object: + - set_state(self, state: object) -> None: + - Declared in `DiscreteQNetwork`: + - _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """ + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscreteQNetwork, self).__init__(state_dim=state_dim, action_dim=1) + self._action_num = action_num + + def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + q_matrix = self.q_values_for_all_actions(states) # [batch_size, action_num] + return q_matrix.gather(dim=1, index=actions).reshape(-1) + + def q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """" + Generates the matrix that contains the Q-values for all potential actions. + The actual logics should be implemented in `_get_q_values_for_all_actions`. + + Args: + states (torch.Tensor): States with shape [batch_size, state_dim]. + + Returns: + q values for all actions with shape [batch_size, action_num] + """ + assert self._policy_net_shape_check(states=states, actions=None) + ret = self._get_q_values_for_all_actions(states) + assert match_shape(ret, (states.shape[0], self.action_num)) + return ret + + @abstractmethod + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of `q_values_for_all_actions`.""" + pass + + def _get_action_num(self) -> int: + return self._action_num + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + if exploring: # The exploring logic should be handles out of the network. + raise NotImplementedError + else: + q_matrix = self.q_values_for_all_actions(states) + _, action = q_matrix.max(dim=1) + return action.unsqueeze(1) diff --git a/maro/rl/policy_v2/__init__.py b/maro/rl/policy_v2/__init__.py new file mode 100644 index 000000000..c38bb6268 --- /dev/null +++ b/maro/rl/policy_v2/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .ac import DiscreteActorCritic +from .dqn import DQN, PrioritizedExperienceReplay +from .pg import DiscretePolicyGradient +from .policy_base import AbsPolicy, DummyPolicy, RLPolicy, RuleBasedPolicy + +__all__ = [ + "DiscreteActorCritic", + "DQN", "PrioritizedExperienceReplay", + "DiscretePolicyGradient", + "AbsPolicy", "DummyPolicy", "RLPolicy", "RuleBasedPolicy" +] diff --git a/maro/rl/policy_v2/ac.py b/maro/rl/policy_v2/ac.py new file mode 100644 index 000000000..13dfbf91a --- /dev/null +++ b/maro/rl/policy_v2/ac.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from collections import defaultdict +from typing import Callable, List, Tuple + +import numpy as np +import torch +from torch.distributions import Categorical + +from maro.communication import SessionMessage +from maro.rl.modeling_v2 import DiscreteVActorCriticNet +from maro.rl.utils import MsgKey, MsgTag, average_grads, discount_cumsum + +from .buffer import Buffer +from .policy_base import RLPolicy +from .policy_interfaces import DiscreteActionMixin, VNetworkMixin + + +class DiscreteActorCritic(VNetworkMixin, DiscreteActionMixin, RLPolicy): + """ + Actor Critic algorithm with separate policy and value models. 
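Analogously, a concrete subclass of the `DiscreteQNetwork` base defined above (the model consumed by the `DQN` policy later in this diff) only needs `_get_q_values_for_all_actions` plus the usual training methods. A hypothetical sketch, with the optimizer and state plumbing again omitted:

import torch

class MyQNet(DiscreteQNetwork):  # hypothetical sketch
    def __init__(self, state_dim: int, action_num: int) -> None:
        super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
        self._fc = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, action_num)
        )

    def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
        return self._fc(states)   # [batch_size, action_num]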
+ + References: + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch. + https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f + + Args: + name (str): Unique identifier for the policy. + ac_net (DiscreteACNet): Multi-task model that computes action distributions and state values. + reward_discount (float): Reward decay as defined in standard RL terminology. + grad_iters (int): Number of gradient steps for each batch or set of batches. Defaults to 1. + critic_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for computing + the critic loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse". + min_logp (float): Lower bound for clamping logP values during learning. This is to prevent logP from becoming + very large in magnitude and causing stability issues. Defaults to None, which means no lower bound. + critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 1.0. + entropy_coef (float): Coefficient for the entropy term in total loss. Defaults to None, in which case the + total loss will not include an entropy term. + clip_ratio (float): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). Defaults to None, + in which case the actor loss is calculated using the usual policy gradient theorem. + lam (float): Lambda value for generalized advantage estimation (TD-Lambda). Defaults to 0.9. + max_trajectory_len (int): Maximum trajectory length that can be held by the buffer (for each agent that uses + this policy). Defaults to 10000. + get_loss_on_rollout (bool): If True, ``get_rollout_info`` will return the loss information (including gradients) + for the trajectories stored in the buffers. The loss information, along with that from other roll-out + instances, can be passed directly to ``update``. Otherwise, it will simply process the trajectories into a + single data batch that can be passed directly to ``learn``. Defaults to False. + device (str): Identifier for the torch device. The ``ac_net`` will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. + """ + def __init__( + self, + name: str, + ac_net: DiscreteVActorCriticNet, + reward_discount: float, + grad_iters: int = 1, + critic_loss_cls: Callable = None, + min_logp: float = None, + critic_loss_coef: float = 1.0, + entropy_coef: float = .0, + clip_ratio: float = None, + lam: float = 0.9, + max_trajectory_len: int = 10000, + get_loss_on_rollout: bool = False, + device: str = None + ) -> None: + if not isinstance(ac_net, DiscreteVActorCriticNet): + raise TypeError("model must be an instance of 'DiscreteVActorCriticNet'") + + super(DiscreteActorCritic, self).__init__(name=name, device=device) + + self._ac_net = ac_net.to(self._device) + self._reward_discount = reward_discount + self._grad_iters = grad_iters + self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() + self._min_logp = min_logp + self._critic_loss_coef = critic_loss_coef + self._entropy_coef = entropy_coef + self._clip_ratio = clip_ratio + self._lam = lam + self._max_trajectory_len = max_trajectory_len + self._get_loss_on_rollout = get_loss_on_rollout + + self._buffer = defaultdict(lambda: Buffer(size=self._max_trajectory_len)) + + def _call_impl(self, states: np.ndarray) -> List[dict]: + """Return a list of action information dict given a batch of states. 
+ + An action information dict contains the action itself, the corresponding log-P value and the corresponding + state value. + """ + actions, logps, values = self.get_actions_with_logps_and_values(states) + return [ + {"action": action, "logp": logp, "value": value} for action, logp, value in zip(actions, logps, values) + ] + + def get_actions_with_logps_and_values(self, states: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self._ac_net.eval() + states = torch.from_numpy(states).to(self._device) + if len(states.shape) == 1: + states = states.unsqueeze(dim=0) + with torch.no_grad(): + actions, logps = self._ac_net.get_actions_and_logps(states, exploring=self._exploring) + values = self._get_v_critic(states) + actions, logps, values = actions.cpu().numpy(), logps.cpu().numpy(), values.cpu().numpy() + return actions, logps, values + + def _get_v_critic(self, states: torch.Tensor) -> torch.Tensor: + return self._ac_net.v_critic(states) + + def _get_v_values(self, states: np.ndarray) -> np.ndarray: + return self._get_v_critic(torch.Tensor(states)).numpy() + + def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None: + assert hasattr(self, '_proxy'), "learn_with_data_parallel is invalid before data_parallel is called." + for _ in range(self._grad_iters): + msg_dict = defaultdict(lambda: defaultdict(dict)) + sub_batch = {} + for i, worker_id in enumerate(worker_id_list): + sub_batch = {key: batch[key][i::len(worker_id_list)] for key in batch} + msg_dict[worker_id][MsgKey.GRAD_TASK][self._name] = sub_batch + msg_dict[worker_id][MsgKey.POLICY_STATE][self._name] = self.get_state() + # data-parallel + self._proxy.isend(SessionMessage( + MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id])) + dones = 0 + loss_info_by_policy = {self._name: []} + for msg in self._proxy.receive(): + if msg.tag == MsgTag.COMPUTE_GRAD_DONE: + for policy_name, loss_info in msg.body[MsgKey.LOSS_INFO].items(): + if isinstance(loss_info, list): + loss_info_by_policy[policy_name] += loss_info + elif isinstance(loss_info, dict): + loss_info_by_policy[policy_name].append(loss_info) + else: + raise TypeError(f"Wrong type of loss_info: {type(loss_info)}") + dones += 1 + if dones == len(msg_dict): + break + # build dummy computation graph by `get_batch_loss` before apply gradients. 
+ _ = self.get_batch_loss(sub_batch, explicit_grad=True) + self.update(loss_info_by_policy[self._name]) + + def _get_action_num(self) -> int: + return self._ac_net.action_num + + def _get_state_dim(self) -> int: + return self._ac_net.state_dim + + def record( + self, + key: str, + state: np.ndarray, + action: dict, + reward: float, + next_state: np.ndarray, + terminal: bool + ) -> None: + self._buffer[key].put(state, action, reward, terminal) + + def get_rollout_info(self) -> dict: + if self._get_loss_on_rollout: + return self.get_batch_loss(self._get_batch(), explicit_grad=True) + else: + return self._get_batch() + + def _get_batch(self) -> dict: + batch = defaultdict(list) + for buf in self._buffer.values(): + trajectory = buf.get() + values = np.append(trajectory["values"], trajectory["last_value"]) + rewards = np.append(trajectory["rewards"], trajectory["last_value"]) + deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] + batch["states"].append(trajectory["states"]) + batch["actions"].append(trajectory["actions"]) + # Returns rewards-to-go, to be targets for the value function + batch["returns"].append(discount_cumsum(rewards, self._reward_discount)[:-1]) + # Generalized advantage estimation using TD(Lambda) + batch["advantages"].append(discount_cumsum(deltas, self._reward_discount * self._lam)) + batch["logps"].append(trajectory["logps"]) + + return {key: np.concatenate(vals) for key, vals in batch.items()} + + def get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> dict: + self._ac_net.train() + states = torch.from_numpy(batch["states"]).to(self._device) + actions = torch.from_numpy(batch["actions"]).to(self._device).long() + logp_old = torch.from_numpy(batch["logps"]).to(self._device) + returns = torch.from_numpy(batch["returns"]).to(self._device) + advantages = torch.from_numpy(batch["advantages"]).to(self._device) + + action_probs = self._ac_net.get_probs(states) + state_values = self._get_v_critic(states) + state_values = state_values.squeeze() + + # actor loss + logp = torch.log(action_probs.gather(1, actions.unsqueeze(1)).squeeze()) # (N,) + logp = torch.clamp(logp, min=self._min_logp, max=.0) + if self._clip_ratio is not None: + ratio = torch.exp(logp - logp_old) + clipped_ratio = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) + actor_loss = -(torch.min(ratio * advantages, clipped_ratio * advantages)).mean() + else: + actor_loss = -(logp * advantages).mean() + + # critic_loss + critic_loss = self._critic_loss_func(state_values, returns) + # entropy + entropy = -Categorical(action_probs).entropy().mean() if self._entropy_coef else 0 + + # total loss + loss = actor_loss + self._critic_loss_coef * critic_loss + self._entropy_coef * entropy + + loss_info = { + "actor_loss": actor_loss.detach().cpu().numpy(), + "critic_loss": critic_loss.detach().cpu().numpy(), + "entropy": entropy.detach().cpu().numpy() if self._entropy_coef else .0, + "loss": loss.detach().cpu().numpy() if explicit_grad else loss + } + if explicit_grad: + loss_info["grad"] = self._ac_net.get_gradients(loss) + + return loss_info + + def data_parallel(self, *args, **kwargs) -> None: + pass # TODO + + def update(self, loss_info_list: List[dict]) -> None: + self._ac_net.apply_gradients(average_grads([loss_info["grad"] for loss_info in loss_info_list])) + + def learn(self, batch: dict) -> None: + for _ in range(self._grad_iters): + self._ac_net.step(self.get_batch_loss(batch)["loss"]) + + def improve(self) -> None: + self.learn(self._get_batch()) + + def 
get_state(self) -> object: + return self._ac_net.get_state() + + def set_state(self, policy_state: object) -> None: + self._ac_net.set_state(policy_state) + + def load(self, path: str) -> None: + self._ac_net.set_state(torch.load(path)) + + def save(self, path: str) -> None: + torch.save(self._ac_net.get_state(), path) diff --git a/maro/rl/policy_v2/buffer.py b/maro/rl/policy_v2/buffer.py new file mode 100644 index 000000000..076526413 --- /dev/null +++ b/maro/rl/policy_v2/buffer.py @@ -0,0 +1,118 @@ +import collections +from dataclasses import dataclass +from typing import Deque + +import numpy as np + + +@dataclass +class BufferElement: + state: np.ndarray + action: int + logp: float + value: float + reward: float + terminal: bool + + +@dataclass +class MultiBufferElement: + state: np.ndarray + actions: np.ndarray + logps: np.ndarray + value: float + reward: float + terminal: bool + + +class Buffer: + """Store a sequence of transitions, i.e., a trajectory. + + Args: + size (int): Buffer capacity, i.e., the maximum number of stored transitions. + """ + def __init__(self, size: int = 10000) -> None: + self._pool: Deque[BufferElement] = collections.deque() + self._size = size + + def put(self, state: np.ndarray, action: dict, reward: float, terminal: bool = False) -> None: + self._pool.append( + BufferElement( + state=state.reshape(1, -1), + action=action.get("action", 0), + logp=action.get("logp", 0.0), + value=action.get("value", 0.0), + reward=reward, + terminal=terminal + ) + ) + if len(self._pool) > self._size: + self._pool.popleft() + # TODO: erase the older elements or raise MLE error? + + def get(self) -> dict: + """Retrieve the latest trajectory segment.""" + if len(self._pool) == 0: + return {} + + new_pool = collections.deque() + if not self._pool[-1].terminal: + new_pool.append(self._pool.pop()) + + ret = { + "states": np.concatenate([elem.state for elem in self._pool], axis=0), + "actions": np.array([elem.action for elem in self._pool], dtype=np.int32), + "logps": np.array([elem.logp for elem in self._pool], dtype=np.float32), + "values": np.array([elem.value for elem in self._pool], dtype=np.float32), + "rewards": np.array([elem.reward for elem in self._pool], dtype=np.float32), + "last_value": self._pool[-1].value + } + + self._pool = new_pool + return ret + + +class MultiBuffer: + """TODO + """ + def __init__(self, agent_num: int, size: int = 10000) -> None: + self._pool: Deque[MultiBufferElement] = collections.deque() + self._agent_num = agent_num + self._size = size + + def put(self, state: np.ndarray, action: dict, reward: float, terminal: bool = False) -> None: + self._pool.append( + MultiBufferElement( + state=state.reshape(1, -1), + actions=np.array(action.get("action", [0] * self._agent_num)).reshape(1, -1), + logps=np.array(action.get("logp", [0.0] * self._agent_num)).reshape(1, -1), + value=action.get("value", 0.0), + reward=reward, + terminal=terminal + ) + ) + if len(self._pool) > self._size: + self._pool.popleft() + # TODO: erase the older elements or raise MLE error? 
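To see how `DiscreteActorCritic._get_batch` above turns a trajectory from `Buffer.get()` into critic targets and GAE advantages, here is a worked sketch with made-up numbers. It assumes `discount_cumsum(x, d)[t] = x[t] + d * x[t + 1] + d^2 * x[t + 2] + ...`, i.e. the SpinningUp-style discounted cumulative sum this module appears to follow:

import numpy as np
from maro.rl.utils import discount_cumsum

gamma, lam = 0.9, 0.9
rewards = np.array([1.0, 0.0, 2.0], dtype=np.float32)   # per-step rewards from the trajectory
values = np.array([0.5, 0.4, 0.3], dtype=np.float32)    # V(s_t) recorded at roll-out time
last_value = 0.2                                         # bootstrap value for the truncated tail

v_ext = np.append(values, last_value)
r_ext = np.append(rewards, last_value)
deltas = r_ext[:-1] + gamma * v_ext[1:] - v_ext[:-1]     # one-step TD errors
returns = discount_cumsum(r_ext, gamma)[:-1]             # rewards-to-go: critic regression targets
advantages = discount_cumsum(deltas, gamma * lam)        # GAE(lambda) advantages for the actor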
+ + def get(self) -> dict: + """Retrieve the latest trajectory segment.""" + if len(self._pool) == 0: + return {} + + new_pool = collections.deque() + if not self._pool[-1].terminal: + new_pool.append(self._pool.pop()) + + ret = { + "states": np.concatenate([elem.state for elem in self._pool], axis=0), # [batch_size, state_dim] + "actions": list(np.concatenate([elem.actions for elem in self._pool], axis=0).T), # list of [batch_size] + "logps": np.concatenate([elem.logps for elem in self._pool], axis=0), # [batch_size, agent_num] + "values": np.array([elem.value for elem in self._pool], dtype=np.float32), # [batch_size] + "rewards": np.array([elem.reward for elem in self._pool], dtype=np.float32), # [batch_size] + "terminals": np.array([elem.terminal for elem in self._pool], dtype=np.bool), # [batch_size] + "last_value": self._pool[-1].value # Scalar + } + + self._pool = new_pool + return ret diff --git a/maro/rl/policy_v2/dqn.py b/maro/rl/policy_v2/dqn.py new file mode 100644 index 000000000..3bc7ca969 --- /dev/null +++ b/maro/rl/policy_v2/dqn.py @@ -0,0 +1,461 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections import defaultdict +from typing import Callable, Iterable, List, Tuple, Union + +import numpy as np +import torch + +from maro.communication import SessionMessage +from maro.rl.exploration import epsilon_greedy +from maro.rl.modeling_v2 import DiscreteQNetwork +from maro.rl.utils import MsgKey, MsgTag, average_grads +from maro.utils import clone + +from .policy_base import RLPolicy +from .policy_interfaces import DiscreteQNetworkMixin +from .replay import ReplayMemory + + +class PrioritizedExperienceReplay: + """Prioritized Experience Replay (PER). + + References: + https://arxiv.org/pdf/1511.05952.pdf + https://github.com/rlcode/per + + The implementation here is based on direct proportional prioritization (the first variant in the paper). + The rank-based variant is not implemented here. + + Args: + replay_memory (ReplayMemory): experience manager the sampler is associated with. + alpha (float): Prioritization strength. Sampling probabilities are calculated according to + P = p_i^alpha / sum(p_k^alpha). Defaults to 0.6. + beta (float): Bias annealing strength using weighted importance sampling (IS) techniques. + IS weights are calculated according to (N * P)^(-beta), where P is the sampling probability. + This value of ``beta`` should not exceed 1.0, which corresponds to full annealing. Defaults to 0.4. + beta_step (float): The amount ``beta`` is incremented by after each get() call until it reaches 1.0. + Defaults to 0.001. + max_priority (float): Maximum priority value to use for new experiences. Defaults to 1e8. 
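A short numeric sketch of the proportional prioritization described above (illustrative numbers; the sum tree used by the implementation below computes the same quantities incrementally):

import numpy as np

alpha, beta, eps = 0.6, 0.4, 1e-7
td_errors = np.array([0.5, 0.1, 2.0])
priorities = (np.abs(td_errors) + eps) ** alpha     # _get_priority: p_i = (|td_error| + eps) ** alpha
probs = priorities / priorities.sum()                # sampling probability P(i) = p_i / sum_k p_k
is_weights = (len(td_errors) * probs) ** (-beta)     # importance-sampling correction (N * P)^(-beta)
is_weights /= is_weights.max()                       # normalized, as in sample()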
+ """ + def __init__( + self, + replay_memory: ReplayMemory, + *, + alpha: float = 0.6, + beta: float = 0.4, + beta_step: float = 0.001, + max_priority: float = 1e8 + ) -> None: + if beta > 1.0: + raise ValueError("beta should be between 0.0 and 1.0") + self._replay_memory = replay_memory + self._sum_tree = np.zeros(2 * self._replay_memory.capacity - 1) + self._alpha = alpha + self._beta = beta + self._beta_step = beta_step + self._eps = 1e-7 + self._max_priority = max_priority + + def total(self) -> float: + """Return the sum of priorities over all experiences.""" + return self._sum_tree[0] + + def set_max_priority(self, indexes: List[int]) -> None: + """Set the priorities of newly added experiences to the maximum value.""" + self.update(indexes, [self._max_priority] * len(indexes)) + + def update(self, indexes: List[int], td_errors: List[float]) -> None: + """Update priority values at given indexes.""" + for idx, err in zip(indexes, td_errors): + priority = self._get_priority(err) + tree_idx = idx + self._replay_memory.capacity - 1 + delta = priority - self._sum_tree[tree_idx] + self._sum_tree[tree_idx] = priority + self._update(tree_idx, delta) + + def sample(self, size: int) -> Tuple[List[int], float]: + """Priority-based sampling.""" + indexes, priorities = [], [] + segment_len = self.total() / size + for i in range(size): + low, high = segment_len * i, segment_len * (i + 1) + sampled_val = np.random.uniform(low=low, high=high) + idx = self._get(0, sampled_val) + data_idx = idx - self._replay_memory.capacity + 1 + indexes.append(data_idx) + priorities.append(self._sum_tree[idx]) + + self._beta = min(1., self._beta + self._beta_step) + sampling_probabilities = np.array(priorities) / (self.total() + 1e-8) + is_weights = np.power(self._replay_memory.size * sampling_probabilities, -self._beta) + is_weights /= (is_weights.max() + 1e-8) + + return indexes, is_weights + + def _get_priority(self, error: Union[torch.Tensor, np.ndarray, float]) -> float: + if isinstance(error, torch.Tensor): + error = error.detach().numpy() + return (np.abs(error) + self._eps) ** self._alpha + + def _update(self, idx: int, delta: float) -> None: + """Propagate priority change all the way to the root node.""" + parent = (idx - 1) // 2 + self._sum_tree[parent] += delta + if parent != 0: + self._update(parent, delta) + + def _get(self, idx: int, sampled_val: float) -> int: + """Get a leaf node according to a randomly sampled value.""" + left = 2 * idx + 1 + right = left + 1 + + if left >= len(self._sum_tree): + return idx + + if sampled_val <= self._sum_tree[left]: + return self._get(left, sampled_val) + else: + return self._get(right, sampled_val - self._sum_tree[left]) + + +class DQN(DiscreteQNetworkMixin, RLPolicy): + """The Deep-Q-Networks algorithm. + + See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details. + + Args: + name (str): Unique identifier for the policy. + q_net (DiscreteQNet): Q-value model. + reward_discount (float): Reward decay as defined in standard RL terminology. + num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1. + update_target_every (int): Number of gradient steps between target model updates. + soft_update_coef (float): Soft update coefficient, e.g., + target_model = (soft_update_coeff) * eval_model + (1-soft_update_coeff) * target_model. + Defaults to 1.0. 
+ double (bool): If True, the next Q values will be computed according to the double DQN algorithm, + i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)). + See https://arxiv.org/pdf/1509.06461.pdf for details. Defaults to False. + exploration_strategy (Tuple[Callable, dict]): A 2-tuple that consists of a) a function that takes a state + (single or batch), an action (single or batch), the total number of possible actions and a set of keyword + arguments, and returns an exploratory action (single or batch depending on the input), and b) a dictionary + of keyword arguments for the function in a) (this will be assigned to the ``_exploration_params`` member + variable). Defaults to (``epsilon_greedy``, {"epsilon": 0.1}). + exploration_scheduling_options (List[tuple]): A list of 3-tuples specifying the exploration schedulers to be + registered to the exploration parameters. Each tuple consists of an exploration parameter name, an + exploration scheduler class (subclass of ``AbsExplorationScheduler``) and keyword arguments for that class. + The exploration parameter name must be a key in the keyword arguments (second element) of + ``exploration_strategy``. Defaults to an empty list. + replay_memory_capacity (int): Capacity of the replay memory. Defaults to 1000000. + random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True, + overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with + wrap-around. Defaults to False. + warmup (int): When the total number of experiences in the replay memory is below this threshold, + ``choose_action`` will return uniformly random actions for warm-up purposes. Defaults to 50000. + rollout_batch_size (int): Size of the experience batch to use as roll-out information by calling + ``get_rollout_info``. Defaults to 1000. + train_batch_size (int): Batch size for training the Q-net. Defaults to 32. + prioritized_replay_kwargs (dict): Keyword arguments for prioritized experience replay. See + ``PrioritizedExperienceReplay`` for details. Defaults to None, in which case experiences will be sampled + from the replay memory uniformly randomly. + device (str): Identifier for the torch device. The ``q_net`` will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. 
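For reference, a short sketch of how the exploration arguments fit together (the parameter values are illustrative; only the tuple shapes come from the code in this file):

from maro.rl.exploration import epsilon_greedy

# exploration_strategy is a 2-tuple of (function, keyword arguments). DQN calls the function as
# func(states, greedy_actions, num_actions, **kwargs), which matches epsilon_greedy's signature.
exploration_strategy = (epsilon_greedy, {"epsilon": 0.4})

# Each scheduling option is a 3-tuple (parameter name, scheduler class, scheduler kwargs), where the
# parameter name must be a key of the kwargs above and the scheduler class is an AbsExplorationScheduler
# subclass; it is instantiated as cls(exploration_params, param_name, **kwargs). Left empty here.
exploration_scheduling_options = []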
+ """ + def __init__( + self, + name: str, + q_net: DiscreteQNetwork, + reward_discount: float = 0.9, + num_epochs: int = 1, + update_target_every: int = 5, + soft_update_coef: float = 0.1, + double: bool = False, + exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}), + exploration_scheduling_options: List[tuple] = None, + replay_memory_capacity: int = 1000000, + random_overwrite: bool = False, + warmup: int = 50000, + rollout_batch_size: int = 1000, + train_batch_size: int = 32, + prioritized_replay_kwargs: dict = None, + device: str = None + ): + super(DQN, self).__init__(name=name, device=device) + + if exploration_scheduling_options is None: + exploration_scheduling_options = [] + + if not isinstance(q_net, DiscreteQNetwork): + raise TypeError("Model must be an instance of 'DiscreteQNetwork'") + + if any(opt[0] not in exploration_strategy[1] for opt in exploration_scheduling_options): + raise ValueError( + f"The first element of an exploration scheduling option must be one of " + f"{list(exploration_strategy[1].keys())}" + ) + + self._q_net = q_net.to(self._device) + self._target_q_net: DiscreteQNetwork = clone(q_net) + self._target_q_net.eval() + self._q_net_version = 0 + self._target_q_net_version = 0 + + self._num_actions = self._q_net.action_num + self._reward_discount = reward_discount + self._num_epochs = num_epochs + self._update_target_every = update_target_every + self._soft_update_coef = soft_update_coef + self._double = double + + self._replay_memory = ReplayMemory( + replay_memory_capacity, self._q_net.state_dim, + self._q_net.action_dim, random_overwrite=random_overwrite + ) + self._warmup = warmup + self._rollout_batch_size = rollout_batch_size + self._train_batch_size = train_batch_size + self._prioritized_replay = prioritized_replay_kwargs is not None + if self._prioritized_replay: + self._per = PrioritizedExperienceReplay(self._replay_memory, **prioritized_replay_kwargs) + else: + self._loss_func = torch.nn.MSELoss() + + self._exploration_func = exploration_strategy[0] + self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing + self._exploration_schedulers = [ + opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options + ] + + def _call_impl(self, states: np.ndarray) -> Iterable: + if self._replay_memory.size < self._warmup: + return np.random.randint(self._num_actions, size=(states.shape[0] if len(states.shape) > 1 else 1,)) + + self._q_net.eval() + states: torch.Tensor = torch.from_numpy(states).to(self._device) + if len(states.shape) == 1: + states = states.unsqueeze(dim=0) + with torch.no_grad(): + actions = self._q_net.get_actions(states, exploring=False) # [batch_size, 1] + actions = actions.squeeze(-1) # [batch_size] + + if not self._exploring: + return actions.cpu().numpy() + else: + return self._exploration_func(states, actions.cpu().numpy(), self._num_actions, **self._exploration_params) + + def data_parallel(self, *args, **kwargs) -> None: + raise NotImplementedError # TODO + + def _get_q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray: + return self._q_net.q_values_for_all_actions(torch.Tensor(states)).numpy() + + def _get_q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + q_matrix = self.q_values_for_all_actions(states) # [batch_size, action_num] + return np.take_along_axis(q_matrix, actions, axis=1) + + def _get_action_num(self) -> int: + return self._q_net.action_num + + def _get_state_dim(self) -> int: + 
return self._q_net.state_dim + + def record( + self, + key: str, + state: np.ndarray, + action: Union[int, float, np.ndarray], + reward: float, + next_state: np.ndarray, + terminal: bool + ) -> None: + if next_state is None: + next_state = np.zeros(state.shape, dtype=np.float32) + + indexes = self._replay_memory.put( + np.expand_dims(state, axis=0), + np.expand_dims(action, axis=0), + np.expand_dims(reward, axis=0), + np.expand_dims(next_state, axis=0), + np.expand_dims(terminal, axis=0) + ) + if self._prioritized_replay: + self._per.set_max_priority(indexes) + + def get_rollout_info(self) -> dict: + """Randomly sample a batch of transitions from the replay memory. + + This is used in a distributed learning setting and the returned data will be sent to its parent instance + on the learning side (serving as the source of the latest model parameters) for training. + """ + return self._replay_memory.sample(self._rollout_batch_size) + + def _get_batch(self, batch_size: int = None) -> dict: + if batch_size is None: + batch_size = self._train_batch_size + if self._prioritized_replay: + indexes, is_weights = self._per.sample(batch_size) + return { + "states": self._replay_memory.states[indexes], + "actions": self._replay_memory.actions[indexes], + "rewards": self._replay_memory.rewards[indexes], + "next_states": self._replay_memory.next_states[indexes], + "terminals": self._replay_memory.terminals[indexes], + "indexes": indexes, + "is_weights": is_weights + } + else: + return self._replay_memory.sample(self._train_batch_size) + + def get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> dict: + """Compute loss for a data batch. + + Args: + batch (dict): A batch containing "states", "actions", "rewards", "next_states" and "terminals" as keys. + explicit_grad (bool): If True, the gradients should be returned as part of the loss information. Defaults + to False. 
+        """
+        self._q_net.train()
+        states: torch.Tensor = torch.from_numpy(batch["states"]).to(self._device)
+        next_states: torch.Tensor = torch.from_numpy(batch["next_states"]).to(self._device)
+        actions: torch.Tensor = torch.from_numpy(batch["actions"]).to(self._device)
+        rewards: torch.Tensor = torch.from_numpy(batch["rewards"]).to(self._device)
+        terminals: torch.Tensor = torch.from_numpy(batch["terminals"]).float().to(self._device)
+
+        # get target Q values
+        with torch.no_grad():
+            if self._double:
+                actions_by_eval_q_net = self._q_net.get_actions(next_states, exploring=False)
+                next_q_values = self._target_q_net.q_values(next_states, actions_by_eval_q_net)
+            else:
+                # use a separate variable so that the sampled batch actions are not overwritten
+                next_actions = self._target_q_net.get_actions(next_states, exploring=False)
+                next_q_values = self._target_q_net.q_values(next_states, next_actions)
+
+        target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach()
+
+        # loss info
+        loss_info = {}
+        q_values = self._q_net.q_values(states, actions)
+        td_errors = target_q_values - q_values
+        if self._prioritized_replay:
+            is_weights = torch.from_numpy(batch["is_weights"]).to(self._device)
+            # importance-sampling weighted squared TD errors
+            loss = (td_errors.pow(2) * is_weights).mean()
+            loss_info["td_errors"], loss_info["indexes"] = td_errors.detach().cpu().numpy(), batch["indexes"]
+        else:
+            loss = self._loss_func(q_values, target_q_values)
+
+        loss_info["loss"] = loss.detach().cpu().numpy() if explicit_grad else loss
+        if explicit_grad:
+            loss_info["grad"] = self._q_net.get_gradients(loss)
+        return loss_info
+
+    def update(self, loss_info_list: List[dict]) -> None:
+        """Update the Q-net parameters with gradients computed by multiple gradient workers.
+
+        Args:
+            loss_info_list (List[dict]): A list of dictionaries containing loss information (including gradients)
+                computed by multiple gradient workers.
+        """
+        if self._prioritized_replay:
+            for loss_info in loss_info_list:
+                self._per.update(loss_info["indexes"], loss_info["td_errors"])
+
+        self._q_net.apply_gradients(average_grads([loss_info["grad"] for loss_info in loss_info_list]))
+        self._q_net_version += 1
+        # soft-update target network
+        if self._q_net_version - self._target_q_net_version == self._update_target_every:
+            self._update_target()
+
+    def learn(self, batch: dict) -> None:
+        """Learn from a batch containing data required for policy improvement.
+
+        Args:
+            batch (dict): A batch containing "states", "actions", "rewards", "next_states" and "terminals" as keys.
+        """
+        self._replay_memory.put(
+            batch["states"], batch["actions"], batch["rewards"], batch["next_states"], batch["terminals"]
+        )
+        self.improve()
+
+    def improve(self) -> None:
+        """Learn using data from the replay memory."""
+        for _ in range(self._num_epochs):
+            loss_info = self.get_batch_loss(self._get_batch())
+            if self._prioritized_replay:
+                self._per.update(loss_info["indexes"], loss_info["td_errors"])
+            self._q_net.step(loss_info["loss"])
+            self._q_net_version += 1
+            if self._q_net_version - self._target_q_net_version == self._update_target_every:
+                self._update_target()
+
+    def _update_target(self) -> None:
+        self._target_q_net.soft_update(self._q_net, self._soft_update_coef)
+        self._target_q_net_version = self._q_net_version
+
+    def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None:
+        assert hasattr(self, '_proxy'), "learn_with_data_parallel is invalid before data_parallel is called."
+ + self._replay_memory.put( + batch["states"], batch["actions"], batch["rewards"], batch["next_states"], batch["terminals"] + ) + for _ in range(self._num_epochs): + msg_dict = defaultdict(lambda: defaultdict(dict)) + for worker_id in worker_id_list: + msg_dict[worker_id][MsgKey.GRAD_TASK][self._name] = self._get_batch( + self._train_batch_size // len(worker_id_list)) + msg_dict[worker_id][MsgKey.POLICY_STATE][self._name] = self.get_state() + # data-parallel by multiple remote gradient workers + self._proxy.isend(SessionMessage( + MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id])) + dones = 0 + loss_info_by_policy = {self._name: []} + for msg in self._proxy.receive(): + if msg.tag == MsgTag.COMPUTE_GRAD_DONE: + for policy_name, loss_info in msg.body[MsgKey.LOSS_INFO].items(): + if isinstance(loss_info, list): + loss_info_by_policy[policy_name] += loss_info + elif isinstance(loss_info, dict): + loss_info_by_policy[policy_name].append(loss_info) + else: + raise TypeError(f"Wrong type of loss_info: {type(loss_info)}") + dones += 1 + if dones == len(msg_dict): + break + # build dummy computation graph before apply gradients. + _ = self.get_batch_loss(self._get_batch(), explicit_grad=True) + self.update(loss_info_by_policy[self._name]) + + def exploration_step(self) -> None: + """Update the exploration parameters according to the exploration scheduler.""" + for sch in self._exploration_schedulers: + sch.step() + + def get_state(self) -> object: + return self._q_net.get_state() + + def set_state(self, policy_state: object) -> None: + self._q_net.set_state(policy_state) + + def load(self, path: str) -> None: + """Load the policy state from disk.""" + checkpoint = torch.load(path) + self._q_net.set_state(checkpoint["q_net"]) + self._q_net_version = checkpoint["q_net_version"] + self._target_q_net.set_state(checkpoint["target_q_net"]) + self._target_q_net_version = checkpoint["target_q_net_version"] + self._replay_memory = checkpoint["replay_memory"] + if self._prioritized_replay: + self._per = checkpoint["prioritized_replay"] + + def save(self, path: str) -> None: + """Save the policy state to disk.""" + policy_state = { + "q_net": self._q_net.get_state(), + "q_net_version": self._q_net_version, + "target_q_net": self._target_q_net.get_state(), + "target_q_net_version": self._target_q_net_version, + "replay_memory": self._replay_memory + } + if self._prioritized_replay: + policy_state["prioritized_replay"] = self._per + torch.save(policy_state, path) diff --git a/maro/rl/policy_v2/maac.py b/maro/rl/policy_v2/maac.py new file mode 100644 index 000000000..ee97e5e1d --- /dev/null +++ b/maro/rl/policy_v2/maac.py @@ -0,0 +1,271 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from collections import defaultdict +from typing import Callable, Iterable, List, Tuple + +import numpy as np +import torch + +from maro.rl.modeling_v2 import DiscretePolicyGradientNetwork +from maro.rl.modeling_v2.critic_model import MultiDiscreteQCriticNetwork +from maro.rl.policy_v2 import RLPolicy +from maro.rl.policy_v2.buffer import MultiBuffer +from maro.rl.policy_v2.policy_interfaces import MultiDiscreteActionMixin +from maro.rl.utils import average_grads + + +class MultiDiscreteActorCritic(MultiDiscreteActionMixin, RLPolicy): + """ + References: + MADDPG paper: https://arxiv.org/pdf/1706.02275.pdf + + Args: + name (str): Unique identifier for the policy. + global_state_dim (int): State dim of the shared part of state. 
+ agent_nets (List[DiscretePolicyGradientNetwork]): Networks for all sub-agents. + critic_net (MultiDiscreteQCriticNetwork): Critic's network. + reward_discount (float): Reward decay as defined in standard RL terminology. + grad_iters (int): Number of gradient steps for each batch or set of batches. Defaults to 1. + min_logp (float): Lower bound for clamping logP values during learning. This is to prevent logP from becoming + very large in magnitude and causing stability issues. Defaults to None, which means no lower bound. + critic_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for computing + the critic loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse". + critic_loss_coef (float): Coefficient of critic loss. + clip_ratio (float): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). Defaults to None, + in which case the actor loss is calculated using the usual policy gradient theorem. + lam (float): Lambda value for generalized advantage estimation (TD-Lambda). Defaults to 0.9. + max_trajectory_len (int): Maximum trajectory length that can be held by the buffer (for each agent that uses + this policy). Defaults to 10000. + get_loss_on_rollout (bool): If True, ``get_rollout_info`` will return the loss information (including gradients) + for the trajectories stored in the buffers. The loss information, along with that from other roll-out + instances, can be passed directly to ``update``. Otherwise, it will simply process the trajectories into a + single data batch that can be passed directly to ``learn``. Defaults to False. + device (str): Identifier for the torch device. The ``ac_net`` will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. 
+ """ + def __init__( + self, + name: str, + global_state_dim: int, + agent_nets: List[DiscretePolicyGradientNetwork], + critic_net: MultiDiscreteQCriticNetwork, + reward_discount: float, + grad_iters: int = 1, + min_logp: float = None, + critic_loss_cls: Callable = None, + critic_loss_coef: float = 1.0, + clip_ratio: float = None, + lam: float = 0.9, + max_trajectory_len: int = 10000, + get_loss_on_rollout: bool = False, + device: str = None + ) -> None: + super(MultiDiscreteActorCritic, self).__init__(name=name, device=device) + + self._critic_net = critic_net + self._total_state_dim = self._critic_net.state_dim + self._global_state_dim = global_state_dim + + self._agent_nets = agent_nets + self._num_sub_agents = len(self._agent_nets) + self._local_state_dims = [net.state_dim - self._global_state_dim for net in self._agent_nets] + assert all(dim >= 0 for dim in self._local_state_dims) + assert self._total_state_dim == sum(self._local_state_dims) + self._global_state_dim + + self._reward_discount = reward_discount + self._grad_iters = grad_iters + self._min_logp = min_logp + self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() + self._critic_loss_coef = critic_loss_coef + self._clip_ratio = clip_ratio + self._lam = lam + self._max_trajectory_len = max_trajectory_len + self._get_loss_on_rollout = get_loss_on_rollout + + self._buffer = defaultdict(lambda: MultiBuffer(agent_num=self._num_sub_agents, size=self._max_trajectory_len)) + + def _get_action_nums(self) -> List[int]: + return [net.action_num for net in self._agent_nets] + + def _get_state_dim(self) -> int: + return self._critic_net.state_dim + + def _call_impl(self, states: np.ndarray) -> Iterable: + actions, logps, values = self.get_actions_with_logps_and_values(states) + return [ + { + "action": action, # [num_sub_agent] + "logp": logp, # [num_sub_agent] + "value": value # Scalar + } for action, logp, value in zip(actions, logps, values) + ] + + def _get_state_list(self, input_states: np.ndarray) -> List[torch.Tensor]: + """Get observable states for all sub-agents. + + Args: + input_states (np.ndarray): global state with shape [batch_size, total_state_dim] + + Returns: + A list of torch.Tensor. 
+
+        """
+        if isinstance(input_states, torch.Tensor):
+            # ``get_batch_loss`` passes a tensor; convert it so that the numpy-based slicing below works
+            input_states = input_states.cpu().numpy()
+
+        state_list = []
+        global_state = input_states[:, -self._global_state_dim:]  # [batch_size, global_state_dim]
+        offset = 0
+        for local_state_dim in self._local_state_dims:
+            local_state = input_states[:, offset:offset + local_state_dim]  # [batch_size, local_state_dim]
+            offset += local_state_dim
+
+            complete_state = np.concatenate([local_state, global_state], axis=1)  # [batch_size, complete_state_dim]
+            complete_state = torch.from_numpy(complete_state).to(self._device)
+            if len(complete_state.shape) == 1:
+                complete_state = complete_state.unsqueeze(dim=0)
+            state_list.append(complete_state)
+        return state_list
+
+    def get_actions_with_logps_and_values(self, input_states: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Get actions, log-probabilities and critic values for a batch of states.
+
+        Args:
+            input_states (np.ndarray): global state with shape [batch_size, total_state_dim]
+
+        Returns:
+            actions: [batch_size, num_sub_agent]
+            logps: [batch_size, num_sub_agent]
+            values: [batch_size]
+        """
+        for net in self._agent_nets:
+            net.eval()
+
+        state_list = self._get_state_list(input_states)
+        with torch.no_grad():
+            actions = []
+            logps = []
+            for net, state in zip(self._agent_nets, state_list):  # iterate `num_sub_agent` times
+                action, logp = net.get_actions_and_logps(state, self._exploring)  # [batch_size], [batch_size]
+                actions.append(action)
+                logps.append(logp)
+            values = self._get_values_by_states_and_actions(torch.from_numpy(input_states).to(self._device), actions)
+
+        actions = np.stack([action.cpu().numpy() for action in actions], axis=1)  # [batch_size, num_sub_agent]
+        logps = np.stack([logp.cpu().numpy() for logp in logps], axis=1)  # [batch_size, num_sub_agent]
+        values = values.cpu().numpy()  # [batch_size]
+
+        return actions, logps, values
+
+    def _get_values_by_states_and_actions(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
+        """Compute joint-action Q-values from the critic.
+
+        Args:
+            states: [batch_size, state_dim]
+            actions: List of torch.Tensor with shape [batch_size]
+
+        Returns:
+            [batch_size]
+        """
+        action_tensor = torch.stack(actions).T  # [batch_size, sub_agent_num]
+        return self._critic_net.q_critic(states, action_tensor)
+
+    def record(
+        self, key: str, state: np.ndarray, action: dict, reward: float,
+        next_state: np.ndarray, terminal: bool
+    ) -> None:
+        self._buffer[key].put(state, action, reward, terminal)
+
+    def get_rollout_info(self) -> dict:
+        if self._get_loss_on_rollout:
+            return self.get_batch_loss(self._get_batch(), explicit_grad=True)
+        else:
+            return self._get_batch()
+
+    def _get_batch(self) -> dict:
+        batch = defaultdict(list)
+        for buf in self._buffer.values():
+            trajectory = buf.get()
+            batch["states"].append(trajectory["states"][:-1])
+            batch["actions"].append(trajectory["actions"][:-1])
+            batch["next_states"].append(trajectory["next_states"][1:])
+            batch["next_actions"].append(trajectory["next_actions"][1:])
+            batch["rewards"].append(trajectory["rewards"][:-1])
+            batch["terminals"].append(trajectory["terminals"][:-1])
+        return {key: np.concatenate(vals) for key, vals in batch.items()}  # batch_size = sum(buffer_length)
+
+    def get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> dict:
+        for net in self._agent_nets:
+            net.train()
+        self._critic_net.train()
+
+        states = torch.from_numpy(batch["states"]).to(self._device)  # [batch_size, total_state_dim]
+        actions = [torch.from_numpy(elem).to(self._device).long() for elem in batch["actions"]]
+        next_states = torch.from_numpy(batch["next_states"]).to(self._device)  # [batch_size, total_state_dim]
+        next_actions = \
[torch.from_numpy(elem).to(self._device).long() for elem in batch["next_actions"]] + + rewards = torch.from_numpy(batch["rewards"]).to(self._device) # [batch_size] + terminals = torch.from_numpy(batch["terminals"]).float().to(self._device) # [batch_size] + + # critic loss + with torch.no_grad(): + next_q_values = self._get_values_by_states_and_actions(next_states, next_actions) + target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() # [batch_size] + q_values = self._get_values_by_states_and_actions(states, actions) # [batch_size] + critic_loss = self._critic_loss_func(q_values, target_q_values) + + # actor losses + state_list = self._get_state_list(states) + actor_losses = [] + for i in range(self._num_sub_agents): + net = self._agent_nets[i] + state = state_list[i] + new_action, _ = net.get_actions_and_logps(state, self._exploring) # [batch_size], [batch_size] + cur_actions = [action for action in actions] + cur_actions[i] = new_action + actor_loss = -self._get_values_by_states_and_actions(states, cur_actions).mean() + actor_losses.append(actor_loss) + + # total loss + loss = sum(actor_losses) + self._critic_loss_coef * critic_loss + + loss_info = { + "critic_loss": critic_loss.detach().cpu().numpy(), + "actor_losses": [loss.detach().cpu().numpy() for loss in actor_losses], + "loss": loss.detach().cpu().numpy() if explicit_grad else loss + } + if explicit_grad: + loss_info["actor_grads"] = [net.get_gradients(loss) for net in self._agent_nets] + loss_info["critic_grad"] = self._critic_net.get_gradients(loss) + + return loss_info + + def data_parallel(self, *args, **kwargs) -> None: + pass # TODO + + def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None: + pass # TODO + + def update(self, loss_info_list: List[dict]) -> None: + for i, net in enumerate(self._agent_nets): + net.apply_gradients(average_grads([loss_info["actor_grads"][i] for loss_info in loss_info_list])) + self._critic_net.apply_gradients(average_grads([loss_info["critic_grad"] for loss_info in loss_info_list])) + + def learn(self, batch: dict) -> None: + for _ in range(self._grad_iters): + loss = self.get_batch_loss(batch)["loss"] + for net in self._agent_nets: + net.step(loss) + self._critic_net.step(loss) + + def improve(self) -> None: + self.learn(self._get_batch()) + + def get_state(self) -> object: + return [net.get_state() for net in self._agent_nets] + + def set_state(self, policy_state: object) -> None: + assert isinstance(policy_state, list) + for net, state in zip(self._agent_nets, policy_state): + net.set_state(state) + + def load(self, path: str) -> None: + self.set_state(torch.load(path)) + + def save(self, path: str) -> None: + torch.save(self.get_state(), path) diff --git a/maro/rl/policy_v2/pg.py b/maro/rl/policy_v2/pg.py new file mode 100644 index 000000000..2643a597b --- /dev/null +++ b/maro/rl/policy_v2/pg.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections import defaultdict +from typing import List, Tuple + +import numpy as np +import torch + +from maro.communication import SessionMessage +from maro.rl.modeling_v2 import DiscretePolicyGradientNetwork +from maro.rl.utils import MsgKey, MsgTag, average_grads, discount_cumsum + +from .buffer import Buffer +from .policy_base import RLPolicy +from .policy_interfaces import DiscreteActionMixin + + +class DiscretePolicyGradient(DiscreteActionMixin, RLPolicy): + """The vanilla Policy Gradient (VPG) algorithm, a.k.a., REINFORCE. 
+ + Reference: https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch. + + Args: + name (str): Unique identifier for the policy. + policy_net (DiscretePolicyNet): Multi-task model that computes action distributions and state values. + It may or may not have a shared bottom stack. + reward_discount (float): Reward decay as defined in standard RL terminology. + grad_iters (int): Number of gradient steps for each batch or set of batches. Defaults to 1. + max_trajectory_len (int): Maximum trajectory length that can be held by the buffer (for each agent that uses + this policy). Defaults to 10000. + get_loss_on_rollout (bool): If True, ``get_rollout_info`` will return the loss information (including gradients) + for the trajectories stored in the buffers. The loss information, along with that from other roll-out + instances, can be passed directly to ``update``. Otherwise, it will simply process the trajectories into a + single data batch that can be passed directly to ``learn``. Defaults to False. + device (str): Identifier for the torch device. The ``policy net`` will be moved to the specified device. If it + is None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. + """ + def __init__( + self, + name: str, + policy_net: DiscretePolicyGradientNetwork, + reward_discount: float, + grad_iters: int = 1, + max_trajectory_len: int = 10000, + get_loss_on_rollout: bool = False, + device: str = None + ) -> None: + super(DiscretePolicyGradient, self).__init__(name=name, device=device) + + if not isinstance(policy_net, DiscretePolicyGradientNetwork): + raise TypeError("model must be an instance of 'DiscretePolicyGradientNetwork'") + + self._policy_net = policy_net.to(self._device) + self._reward_discount = reward_discount + self._grad_iters = grad_iters + self._max_trajectory_len = max_trajectory_len + self._get_loss_on_rollout = get_loss_on_rollout + + self._buffer = defaultdict(lambda: Buffer(size=self._max_trajectory_len)) + + def _call_impl(self, states: np.ndarray) -> List[dict]: + """Return a list of action information dict given a batch of states. + + An action information dict contains the action itself and the corresponding log-P value. + """ + actions, logps = self.get_actions_with_logps(states) + return [{"action": action, "logp": logp} for action, logp in zip(actions, logps)] + + def get_actions_with_logps(self, states: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Return actions an log-P value based on states. + """ + self._policy_net.eval() + with torch.no_grad(): + states: torch.Tensor = torch.from_numpy(states).to(self._device) + actions, logps = self._policy_net.get_actions_and_logps(states, exploring=self._exploring) + actions, logps = actions.cpu().numpy(), logps.cpu().numpy() + return actions, logps + + def data_parallel(self, *args, **kwargs) -> None: + raise NotImplementedError # TODO + + def _get_action_num(self) -> int: + return self._policy_net.action_num + + def _get_state_dim(self) -> int: + return self._policy_net.state_dim + + def record( + self, + key: str, + state: np.ndarray, + action: dict, + reward: float, + next_state: np.ndarray, + terminal: bool + ) -> None: + self._buffer[key].put(state, action, reward, terminal) + + def get_rollout_info(self) -> dict: + """Extract information from the recorded transitions. 
+
+        Returns:
+            Loss information (including gradients) for the latest trajectory segment in the buffer if
+            ``get_loss_on_rollout`` is True; otherwise, the latest trajectory segment with pre-computed
+            return values.
+        """
+        if self._get_loss_on_rollout:
+            return self.get_batch_loss(self._get_batch(), explicit_grad=True)
+        else:
+            return self._get_batch()
+
+    def _get_batch(self) -> dict:
+        batch = defaultdict(list)
+        for buf in self._buffer.values():
+            trajectory = buf.get()
+            rewards = np.append(trajectory["rewards"], trajectory["last_val"])
+            batch["states"].append(trajectory["states"])
+            # rewards-to-go, used as the weights in the policy gradient loss
+            batch["returns"].append(discount_cumsum(rewards, self._reward_discount)[:-1])
+
+        return {key: np.concatenate(vals) for key, vals in batch.items()}
+
+    def get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> dict:
+        """Compute the policy gradient (REINFORCE) loss for a data batch.
+
+        Args:
+            batch (dict): A batch containing "states" and "returns" as keys.
+            explicit_grad (bool): If True, the gradients should be returned as part of the loss information. Defaults
+                to False.
+        """
+        self._policy_net.train()
+        returns = torch.from_numpy(np.asarray(batch["returns"])).to(self._device)
+
+        logps = self._policy_net.get_logps(batch["states"])
+        loss = -(logps * returns).mean()
+        loss_info = {"loss": loss.detach().cpu().numpy() if explicit_grad else loss}
+        if explicit_grad:
+            loss_info["grad"] = self._policy_net.get_gradients(loss)
+        return loss_info
+
+    def update(self, loss_info_list: List[dict]) -> None:
+        """Update the model parameters with gradients computed by multiple roll-out instances or gradient workers.
+
+        Args:
+            loss_info_list (List[dict]): A list of dictionaries containing loss information (including gradients)
+                computed by multiple roll-out instances or gradient workers.
+        """
+        self._policy_net.apply_gradients(average_grads([loss_info["grad"] for loss_info in loss_info_list]))
+
+    def learn(self, batch: dict) -> None:
+        """Learn from a batch containing data required for policy improvement.
+
+        Args:
+            batch (dict): A batch containing "states" and "returns" as keys.
+        """
+        for _ in range(self._grad_iters):
+            # ``get_batch_loss`` without ``explicit_grad`` returns the loss tensor under the "loss" key
+            self._policy_net.step(self.get_batch_loss(batch)["loss"])
+
+    def improve(self) -> None:
+        """Learn using data from the buffer."""
+        self.learn(self._get_batch())
+
+    def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None:
+        assert hasattr(self, '_proxy'), "learn_with_data_parallel is invalid before data_parallel is called."
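+        # Note (added comment): unlike the DQN variant above, the batch here is sharded round-robin
+        # (``batch[key][i::num_workers]``) across the gradient workers; the returned gradient information
+        # is then applied locally after a dummy computation graph is built.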
+ + for _ in range(self._grad_iters): + msg_dict = defaultdict(lambda: defaultdict(dict)) + sub_batch = {} + for i, worker_id in enumerate(worker_id_list): + sub_batch = {key: batch[key][i::len(worker_id_list)] for key in batch} + msg_dict[worker_id][MsgKey.GRAD_TASK][self._name] = sub_batch + msg_dict[worker_id][MsgKey.POLICY_STATE][self._name] = self.get_state() + # data-parallel + self._proxy.isend(SessionMessage( + MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id])) + dones = 0 + loss_info_by_policy = {self._name: []} + for msg in self._proxy.receive(): + if msg.tag == MsgTag.COMPUTE_GRAD_DONE: + for policy_name, loss_info in msg.body[MsgKey.LOSS_INFO].items(): + if isinstance(loss_info, list): + loss_info_by_policy[policy_name] += loss_info + elif isinstance(loss_info, dict): + loss_info_by_policy[policy_name].append(loss_info["grad"]) + else: + raise TypeError(f"Wrong type of loss_info: {type(loss_info)}") + dones += 1 + if dones == len(msg_dict): + break + # build dummy computation graph before apply gradients. + _ = self.get_batch_loss(sub_batch, explicit_grad=True) + self._policy_net.step(loss_info_by_policy[self._name]) + + def get_state(self) -> object: + return self._policy_net.get_state() + + def set_state(self, policy_state: object) -> None: + self._policy_net.set_state(policy_state) + + def load(self, path: str) -> None: + """Load the policy state from disk.""" + self._policy_net.set_state(torch.load(path)) + + def save(self, path: str) -> None: + """Save the policy state to disk.""" + torch.save(self._policy_net.get_state(), path) diff --git a/maro/rl/policy_v2/policy_base.py b/maro/rl/policy_v2/policy_base.py new file mode 100644 index 000000000..9d6f7875a --- /dev/null +++ b/maro/rl/policy_v2/policy_base.py @@ -0,0 +1,292 @@ +from abc import abstractmethod +from typing import Iterable, List, Optional + +import numpy as np +import torch + +from maro.communication import Proxy +from maro.rl.policy_v2.policy_interfaces import ShapeCheckMixin +from maro.rl.utils import match_shape + + +class AbsPolicy(object): + """Abstract policy class. + + All concrete classes that inherit `AbsPolicy` should implement the following abstract methods: + - __call__(self, states: object) -> object: + - _get_state_dim(self) -> int: + """ + + def __init__(self, name: str) -> None: + """ + + Args: + name (str): Unique identifier for the policy. + """ + super().__init__() + print(f"Initializing {self.__class__.__module__}.{self.__class__.__name__}") + self._name = name + + @property + def name(self) -> str: + return self._name + + @abstractmethod + def __call__(self, states: object) -> object: + """Get actions and other auxiliary information based on states. + + Args: + states (object): environment states. + + Returns: + Actions and other auxiliary information based on states. + The format of the returns is defined by the policy. + """ + pass + + @property + def state_dim(self) -> int: + return self._get_state_dim() + + @abstractmethod + def _get_state_dim(self) -> int: + pass + + +class DummyPolicy(AbsPolicy): + """Dummy policy that does nothing. + + Note that the meaning of a "None" action may depend on the scenario. + """ + + def __init__(self, name: str) -> None: + super(DummyPolicy, self).__init__(name) + + def __call__(self, states: object) -> object: + return None + + def _get_state_dim(self) -> int: + return -1 + + +class RuleBasedPolicy(AbsPolicy): + """ + Rule-based policy that generates actions according to a fixed rule. 
+    The rule is immutable, which means a rule-based policy is not trainable.
+
+    All concrete classes that inherit `RuleBasedPolicy` should implement the following abstract methods:
+        - Declared in `AbsPolicy`:
+            - _get_state_dim(self) -> int:
+        - Declared in `RuleBasedPolicy`:
+            - _rule(self, state: object) -> object:
+    """
+
+    def __init__(self, name: str) -> None:
+        super(RuleBasedPolicy, self).__init__(name)
+
+    def __call__(self, states: object) -> object:
+        return self._rule(states)
+
+    @abstractmethod
+    def _rule(self, state: object) -> object:
+        """The rule that should be implemented by inheritors."""
+        pass
+
+
+class RLPolicy(ShapeCheckMixin, AbsPolicy):
+    """Policy that learns from simulation experiences.
+    Reinforcement learning (RL) policies should inherit from this.
+
+    All concrete classes that inherit `RLPolicy` should implement the following abstract methods:
+        - Declared in `AbsPolicy`:
+            - _get_state_dim(self) -> int:
+        - Declared in `RLPolicy`:
+            - _call_impl(self, states: np.ndarray) -> Iterable:
+            - record(self, ...) -> None:
+            - get_rollout_info(self) -> object:
+            - get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> object:
+            - data_parallel(self, *args, **kwargs) -> None:
+            - learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None:
+            - update(self, loss_info_list: List[dict]) -> None:
+            - learn(self, batch: dict) -> None:
+            - improve(self) -> None:
+            - get_state(self) -> object:
+            - set_state(self, policy_state: object) -> None:
+            - load(self, path: str) -> None:
+            - save(self, path: str) -> None:
+    """
+    def __init__(self, name: str, device: str) -> None:
+        """
+        Args:
+            name (str): Name of the policy.
+            device (str): Identifier for the torch device used to train the model. If None, the device is set to
+                "cpu" if cuda is unavailable and "cuda" otherwise.
+        """
+        super(RLPolicy, self).__init__(name)
+        self._exploration_params = {}
+        self._exploring = True
+        # no proxy until ``data_parallel`` or ``data_parallel_with_existing_proxy`` is called
+        self._proxy: Optional[Proxy] = None
+
+        self._device = torch.device(device) if device is not None \
+            else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    @property
+    def exploration_params(self) -> dict:
+        return self._exploration_params
+
+    def explore(self) -> None:
+        """Switch the policy to the exploring mode."""
+        self._exploring = True
+
+    def exploit(self) -> None:
+        """Switch the policy to the exploiting mode."""
+        self._exploring = False
+
+    def __call__(self, states: np.ndarray) -> Iterable:
+        assert self._shape_check(states, None)
+        return self._call_impl(states)
+
+    @abstractmethod
+    def _call_impl(self, states: np.ndarray) -> Iterable:
+        """The implementation of the `__call__` method. Actual logic should be implemented here."""
+        pass
+
+    def _shape_check(self, states: np.ndarray, actions: Optional[np.ndarray]) -> bool:
+        return all([
+            states.shape[0] > 0 and match_shape(states, (None, self.state_dim)),
+            actions is None or (actions.shape[0] > 0 and match_shape(actions, (None, 1))),
+            actions is None or states.shape[0] == actions.shape[0]
+        ])
+
+    @abstractmethod
+    def record(
+        self,
+        agent_id: str,
+        state: np.ndarray,
+        action: object,
+        reward: float,
+        next_state: np.ndarray,
+        terminal: bool
+    ) -> None:
+        """Record a transition in an internal buffer or memory.
+
+        Since we may have multiple agents sharing this policy, the internal buffer / memory should use the agents'
+        names to separate storage for these agents. The ``agent_id`` parameter serves this purpose.
+        """
+        pass
+
+    @abstractmethod
+    def get_rollout_info(self) -> object:  # TODO: return type?
+        """Extract information from the recorded transitions.
+ + Implement this method if you are doing distributed learning. What this function returns will be used to update + policy parameters on the learning side (abstracted through ``AbsPolicyManager``) with or without invoking the + policy improvement algorithm, depending on the type of information. If you want the policy improvement algorithm + to be invoked on roll-out instances (i.e., in distributed fashion), this should return loss information (which + can be obtained by calling ``get_batch_loss`` function with ``explicit_grad`` set to True) to be used by + ``update`` on the learning side. If you want the policy improvement algorithm to be invoked on the learning + side, this should return a data batch to be used by ``learn`` on the learning side. See the implementation of + this function in ``ActorCritic`` for reference. + """ + pass + + @abstractmethod + def get_batch_loss(self, batch: dict, explicit_grad: bool = False) -> object: # TODO: return type? + """Compute policy improvement information, i.e., loss, from a data batch. + + This can be used as a sub-routine in ``learn`` and ``improve``, as these methods usually require computing + loss from a batch. + + Args: + batch (dict): Data batch to compute the policy improvement information for. + explicit_grad (bool): If True, the gradients should be explicitly returned. Defaults to False. + """ + pass + + @abstractmethod + def data_parallel(self, *args, **kwargs) -> None: + """"Initialize a proxy in the policy, for data-parallel training. + Using the same arguments as `Proxy`.""" + pass + + def data_parallel_with_existing_proxy(self, proxy: Proxy) -> None: + """"Initialize a proxy in the policy with an existing one, for data-parallel training.""" + self._proxy = proxy + + def exit_data_parallel(self) -> None: + if self._proxy is not None: + self._proxy.close() + + @abstractmethod + def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None: + pass + + @abstractmethod + def update(self, loss_info_list: List[dict]) -> None: + """Update with loss information computed by multiple sources. + + There are two possible scenarios where you need to implement this interface: 1) if you are doing distributed + learning and want each roll-out instance to collect information that can be used to update policy parameters + on the learning side (abstracted through ``AbsPolicyManager``) without invoking the policy improvement + algorithm. Such information usually includes gradients with respect to the policy parameters. An example where + this can be useful is the Asynchronous Advantage Actor Critic (A3C) (https://arxiv.org/abs/1602.01783); + 2) if you are computing loss in data-parallel fashion, i.e., by splitting a data batch to several smaller + batches and sending them to a set of remote workers for parallelized loss computation. + + Args: + loss_info_list (List[dict]): A list of dictionaries containing loss information (e.g., gradients) computed + by multiple sources. + """ + pass + + @abstractmethod + def learn(self, batch: dict) -> None: + """Learn from a batch of roll-out data. + + Implement this interface if you are doing distributed learning and want the roll-out instances to collect + information that can be used to update policy parameters on the learning side (abstracted through + ``AbsPolicyManager``) using the policy improvement algorithm. + + Args: + batch (dict): Training data to train the policy with. + """ + pass + + @abstractmethod + def improve(self) -> None: + """Learn using data collected locally. 
+ + Implement this interface if you are doing single-threaded learning where a single policy instance is used for + roll-out and training. The policy should have some kind of internal buffer / memory to store roll-out data and + use as the source of training data. + """ + pass + + @abstractmethod + def get_state(self) -> object: + """Return the current state of the policy. + + The implementation must be in correspondence with that of ``set_state``. For example, if a torch model + is contained in the policy, ``get_state`` may include a call to ``state_dict()`` on the model, while + ``set_state`` should accordingly include ``load_state_dict()``. + """ + pass + + @abstractmethod + def set_state(self, policy_state: object) -> None: + """Set the policy state to ``policy_state``. + + The implementation must be in correspondence with that of ``get_state``. For example, if a torch model + is contained in the policy, ``set_state`` may include a call to ``load_state_dict()`` on the model, while + ``get_state`` should accordingly include ``state_dict()``. + """ + pass + + @abstractmethod + def load(self, path: str) -> None: + """Load the policy state from disk.""" + pass + + @abstractmethod + def save(self, path: str) -> None: + """Save the policy state to disk.""" + pass diff --git a/maro/rl/policy_v2/policy_interfaces.py b/maro/rl/policy_v2/policy_interfaces.py new file mode 100644 index 000000000..a61b7cbf7 --- /dev/null +++ b/maro/rl/policy_v2/policy_interfaces.py @@ -0,0 +1,186 @@ +from abc import abstractmethod +from typing import List, Optional, Tuple, Union + +import numpy as np + +from maro.rl.utils import match_shape + +""" +Mixins for policies. + +Mixins have only methods (abstract or non abstract) which define a set of functions that a type of policies should have. +Abstract methods should be implemented by lower-level mixins or policy classes that inherit the mixin. + +A policy class could inherit multiple mixins so that the combination of mixins determines the entire set of methods +of this policy. +""" + + +class DiscreteActionMixin: + """ + Mixin for policies that generate discrete actions. + + All concrete classes that inherit `DiscreteActionMixin` should implement the following abstract methods: + - _get_action_num(self) -> int: + """ + + @property + def action_num(self) -> int: + return self._get_action_num() + + @abstractmethod + def _get_action_num(self) -> int: + pass + + +class MultiDiscreteActionMixin: + """ + Mixin for multi-agent policies that generate discrete actions. + + All concrete classes that inherit `MultiDiscreteActionMixin` should implement the following abstract methods: + - _get_action_nums(self) -> List[int]: + """ + + @property + def action_nums(self) -> List[int]: + return self._get_action_nums() + + @abstractmethod + def _get_action_nums(self) -> List[int]: + pass + + +class ContinuousActionMixin: + """ + Mixin for policies that generate continuous actions. 
+ + All concrete classes that inherit `ContinuousActionMixin` should implement the following abstract methods: + - _get_action_range(self) -> Tuple[Union[int, float, np.ndarray], Union[int, float, np.ndarray]]: + """ + + def action_range(self) -> Tuple[Union[int, float, np.ndarray], Union[int, float, np.ndarray]]: + return self._get_action_range() + + @abstractmethod + def _get_action_range(self) -> Tuple[Union[int, float, np.ndarray], Union[int, float, np.ndarray]]: + pass + + +class ShapeCheckMixin: + """ + Mixin that contains the `_shape_check` method, which is used for checking whether the states and actions + have valid shapes. Usually, it should contains three parts: + 1. Check of states' shape. + 2. Check of actions' shape. + 3. Check whether states and actions have identical batch sizes. + + `actions` is optional. If it is None, it means we do not need to check action related issues. + """ + + @abstractmethod + def _shape_check(self, states: np.ndarray, actions: Optional[np.ndarray]) -> bool: + pass + + +class QNetworkMixin(ShapeCheckMixin): + """ + Mixin for policies that have a Q-network in it, no matter how it is used. For example, + both DQN policies and Actor-Critic policies that use a Q-network as the critic should inherit this mixin. + + All concrete classes that inherit `ContinuousActionMixin` should implement the following abstract methods: + - Declared in `ShapeCheckMixin`: + - _shape_check(self, states: np.ndarray, actions: Optional[np.ndarray]) -> bool: + - Declared in `QNetworkMixin`: + - _get_q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + """ + + def q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + """Returns Q-values based on given states and actions. + The actual logics should be implemented in `_get_q_values`. + + Args: + states (np.ndarray): states with shape [batch_size, state_dim] + actions (np.ndarray): actions with shape [batch_size, action_dim] + + Returns: + Q-values (np.ndarray) with shape [batch_size] + """ + assert self._shape_check(states, actions) + ret = self._get_q_values(states, actions) + assert match_shape(ret, (states.shape[0],)) + return ret + + @abstractmethod + def _get_q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + """Implementation of `q_values`.""" + pass + + +class DiscreteQNetworkMixin(DiscreteActionMixin, QNetworkMixin): + """ + Combination of DiscreteActionMixin and QNetworkMixin. + + All concrete classes that inherit `DiscreteQNetworkMixin` should implement the following abstract methods: + - Declared in `ShapeCheckMixin`: + - _shape_check(self, states: np.ndarray, actions: Optional[np.ndarray]) -> bool: + - Declared in `QNetworkMixin`: + - _get_q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + - Declared in `DiscreteActionMixin`: + - _get_action_num(self) -> int: + - Declared in `DiscreteQNetworkMixin`: + - _get_q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray: + """ + + def q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray: + """ + Returns Q-values for all actions based on given states. + The actual logics should be implemented in `_get_q_values_for_all_actions`. 
+ + Args: + states (np.ndarray): States with shape [batch_size, state_dim] + + Returns: + Q-values (np.ndarray) with shape [batch_size, action_num] + """ + assert self._shape_check(states, None) + ret = self._get_q_values_for_all_actions(states) + assert match_shape(ret, (states.shape[0], self.action_num)) + return ret + + @abstractmethod + def _get_q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray: + """Implementation of `q_values_for_all_actions`.""" + pass + + +class VNetworkMixin(ShapeCheckMixin): + """ + Mixin for policies that have a V-network in it. Similar to QNetworkMixin. + + All concrete classes that inherit `VNetworkMixin` should implement the following abstract methods: + - Declared in `ShapeCheckMixin`: + - _shape_check(self, states: np.ndarray, actions: Optional[np.ndarray]) -> bool: + - Declared in `VNetworkMixin`: + - _get_v_values(self, states: np.ndarray) -> np.ndarray: + """ + + def v_values(self, states: np.ndarray) -> np.ndarray: + """ + Returns Q-values based on given states. + The actual logics should be implemented in `_get_v_values`. + + Args: + states (np.ndarray): [batch_size, state_dim] + + Returns: + V-values (np.ndarray): [batch_size] + """ + assert self._shape_check(states, None) + ret = self._get_v_values(states) + assert match_shape(ret, (states.shape[0],)) + return ret + + @abstractmethod + def _get_v_values(self, states: np.ndarray) -> np.ndarray: + """Implementation of `v_values`.""" + pass diff --git a/maro/rl/policy_v2/replay.py b/maro/rl/policy_v2/replay.py new file mode 100644 index 000000000..0b6b9e844 --- /dev/null +++ b/maro/rl/policy_v2/replay.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import numpy as np + + +class ReplayMemory: + """Storage facility for simulation experiences. + + This implementation uses a dictionary of lists as the internal data structure. The objects for each key + are stored in a list. + + Args: + capacity (int): Maximum number of experiences that can be stored. + state_dim (int): Dimension of flattened state. + action_dim (int): Action dimension. Defaults to 1. + random_overwrite (bool): This specifies overwrite behavior when the capacity is reached. If this is True, + overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with + wrap-around. Defaults to False. 
+ """ + def __init__(self, capacity: int, state_dim: int, action_dim: int = 1, random_overwrite: bool = False): + super().__init__() + self._state_dim = state_dim + self._action_dim = action_dim + self._capacity = capacity + self._random_overwrite = random_overwrite + self.states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + if action_dim > 1: + self.actions = np.zeros((self._capacity, self._action_dim), dtype=np.float32) + else: + self.actions = np.zeros(self._capacity, dtype=np.int64) + self.rewards = np.zeros(self._capacity, dtype=np.float32) + self.next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + self.terminals = np.zeros(self._capacity, dtype=np.bool) + self._ptr = 0 + + @property + def capacity(self): + """Capacity of the memory.""" + return self._capacity + + @property + def random_overwrite(self): + """Overwrite method after the memory has reached capacity.""" + return self._random_overwrite + + @property + def size(self): + """Current number of experiences stored.""" + return self._ptr + + def put( + self, + states: np.ndarray, + actions: np.ndarray, + rewards: np.ndarray, + next_states: np.ndarray, + terminals: np.ndarray + ): + """Put SARS and terminal flags in the memory.""" + assert len(states) == len(actions) == len(rewards) == len(next_states) == len(terminals) + added = len(states) + if added > self._capacity: + raise ValueError("size of added items should not exceed the capacity.") + + if self._ptr + added <= self._capacity: + indexes = np.arange(self._ptr, self._ptr + added) + # follow the overwrite rule set at init + else: + overwrites = self._ptr + added - self._capacity + indexes = np.concatenate([ + np.arange(self._ptr, self._capacity), + np.random.choice(self._ptr, size=overwrites, replace=False) if self._random_overwrite + else np.arange(overwrites) + ]) + + self.states[indexes] = states + self.actions[indexes] = actions + self.rewards[indexes] = rewards + self.next_states[indexes] = next_states + + self._ptr = min(self._ptr + added, self._capacity) + return indexes + + def sample(self, size: int) -> dict: + """Obtain a random sample.""" + indexes = np.random.choice(self._ptr, size=size) + return { + "states": self.states[indexes], + "actions": self.actions[indexes], + "rewards": self.rewards[indexes], + "next_states": self.next_states[indexes], + "terminals": self.terminals[indexes] + } diff --git a/maro/rl/utils/__init__.py b/maro/rl/utils/__init__.py index 8f21f4c52..2cc7a4fe8 100644 --- a/maro/rl/utils/__init__.py +++ b/maro/rl/utils/__init__.py @@ -3,6 +3,7 @@ from .gradient_averaging import average_grads from .message_enums import MsgKey, MsgTag +from .torch_util import match_shape from .trajectory_computation import discount_cumsum -__all__ = ["MsgKey", "MsgTag", "average_grads", "discount_cumsum"] +__all__ = ["MsgKey", "MsgTag", "average_grads", "discount_cumsum", "match_shape"] diff --git a/maro/rl/utils/torch_util.py b/maro/rl/utils/torch_util.py new file mode 100644 index 000000000..4971a982b --- /dev/null +++ b/maro/rl/utils/torch_util.py @@ -0,0 +1,23 @@ +from typing import Union + +import numpy as np +import torch + + +def match_shape(tensor: Union[torch.Tensor, np.ndarray], shape: tuple) -> bool: + """Check if a torch.Tensor/np.ndarray could match the expected shape. + + Args: + tensor: torch.Tensor or np.ndarray + shape: The expected shape tuple. If an element in this tuple is None, it means this dimension could match any + value (usually used for the `batch_size` dimension). 
+ + Returns: + Whether the tensor could match the expected shape. + """ + if len(tensor.shape) != len(shape): + return False + for val, expected in zip(tensor.shape, shape): + if expected is not None and expected != val: + return False + return True
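For reference, a minimal illustration of how ``match_shape`` (exported from ``maro.rl.utils`` by this change) treats a ``None`` entry as a wildcard dimension; the arrays below are hypothetical and not part of the diff:

    import numpy as np

    from maro.rl.utils import match_shape

    states = np.zeros((32, 10), dtype=np.float32)  # e.g. a [batch_size, state_dim] batch
    match_shape(states, (None, 10))  # True: any batch size, state_dim must equal 10
    match_shape(states, (None, 8))   # False: the second dimension differs
    match_shape(states, (32,))       # False: the number of dimensions differs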