From 8e4ad49bf7cedf9d0b6a1d13203e38340ac15768 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 21 Dec 2021 14:10:28 +0800 Subject: [PATCH] Preparation for data parallel (#420) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * Rename train worker to train ops; add placeholder for abstract methods; * Lint Co-authored-by: GQ.Chen --- maro/rl/data_parallelism/task_queue.py | 31 +- maro/rl/utils/message_enums.py | 2 + maro/rl_v3/learning/rollout_manager.py | 1 + maro/rl_v3/learning/trainer_manager.py | 3 +- maro/rl_v3/model/abs_net.py | 7 +- maro/rl_v3/model/multi_q_net.py | 1 + maro/rl_v3/model/policy_net.py | 3 +- maro/rl_v3/model/q_net.py | 1 + maro/rl_v3/model/v_net.py | 1 + maro/rl_v3/policy/abs_policy.py | 7 +- maro/rl_v3/policy/continuous_rl_policy.py | 7 +- maro/rl_v3/policy/discrete_rl_policy.py | 13 +- maro/rl_v3/policy_trainer/__init__.py | 9 +- maro/rl_v3/policy_trainer/abs_train_ops.py | 189 ++++++++ maro/rl_v3/policy_trainer/abs_trainer.py | 111 +++-- maro/rl_v3/policy_trainer/ac.py | 242 +++++++--- maro/rl_v3/policy_trainer/ddpg.py | 248 +++++++--- maro/rl_v3/policy_trainer/discrete_maddpg.py | 224 --------- .../distributed_discrete_maddpg.py | 455 ++++++++++++++++++ maro/rl_v3/policy_trainer/dqn.py | 180 +++++-- maro/rl_v3/policy_trainer/maac.py | 158 ------ maro/rl_v3/replay_memory/__init__.py | 4 +- maro/rl_v3/replay_memory/replay_memory.py | 2 +- maro/rl_v3/tmp_example_multi/env_sampler.py | 1 + maro/rl_v3/tmp_example_multi/main.py | 1 + maro/rl_v3/tmp_example_multi/nets.py | 20 +- maro/rl_v3/tmp_example_multi/policies.py | 17 +- maro/rl_v3/tmp_example_single/env_sampler.py | 1 + maro/rl_v3/tmp_example_single/main.py | 1 + maro/rl_v3/tmp_example_single/nets.py | 31 +- maro/rl_v3/tmp_example_single/policies.py | 1 + maro/rl_v3/workflow.py | 2 +- maro/rl_v3/workflows/grad_worker.py | 70 +++ 33 files changed, 1374 insertions(+), 670 deletions(-) create mode 100644 maro/rl_v3/policy_trainer/abs_train_ops.py delete mode 100644 maro/rl_v3/policy_trainer/discrete_maddpg.py create mode 100644 maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py delete mode 100644 maro/rl_v3/policy_trainer/maac.py create mode 100644 maro/rl_v3/workflows/grad_worker.py diff --git a/maro/rl/data_parallelism/task_queue.py b/maro/rl/data_parallelism/task_queue.py index 7f88e916d..2efb46597 100644 --- a/maro/rl/data_parallelism/task_queue.py +++ b/maro/rl/data_parallelism/task_queue.py @@ -5,6 +5,8 @@ from multiprocessing import Manager, Process, Queue, managers from typing import Dict, List +import torch + from maro.communication import Proxy, SessionMessage from maro.rl.utils import MsgKey, MsgTag from maro.utils import DummyLogger, Logger @@ 
-34,13 +36,21 @@ def request_workers(self, task_queue_server_name="TASK_QUEUE"): return worker_list # TODO: rename this method - def submit(self, worker_id_list: List, batch_list: List, policy_state: Dict, policy_name: str): - """Learn a batch of data on several grad workers.""" + def submit( + self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str, + scope: str = None + ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]: + """Learn a batch of data on several grad workers. + For each policy, send a list of batches and states to the grad workers, and receive a list of gradients. + The results are actually from the train worker's `get_batch_grad()` method, with type: + Dict[str, Dict[int, Dict[str, torch.Tensor]]], which means {scope: {worker_id: {param_name: grad_value}}}""" msg_dict = defaultdict(lambda: defaultdict(dict)) loss_info_by_policy = {policy_name: []} - for worker_id, batch in zip(worker_id_list, batch_list): + for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list): msg_dict[worker_id][MsgKey.GRAD_TASK][policy_name] = batch + msg_dict[worker_id][MsgKey.TENSOR][policy_name] = tensor_dict msg_dict[worker_id][MsgKey.POLICY_STATE][policy_name] = policy_state + msg_dict[worker_id][MsgKey.GRAD_SCOPE][policy_name] = scope # data-parallel by multiple remote gradient workers self._proxy.isend(SessionMessage( MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id])) @@ -73,6 +83,21 @@ def task_queue( proxy_kwargs: dict = {}, logger: Logger = DummyLogger() ): + """The queue to manage data parallel tasks. The task queue communicates with gradient workers, + maintaining the busy/idle status of workers. Clients send requests to the task queue, which + assigns available workers to the requests. The task queue follows the `producer-consumer` model, + consisting of two queues: task_pending and task_assigned. Besides, the task queue supports task priorities and + adding/deleting workers. + + Args: + worker_ids (List[int]): Worker ids to initialize. + num_hosts (int): The number of policy hosts. Will be renamed in RL v3. + num_policies (int): The number of policies. + single_task_limit (float): The maximum proportion of resources that a single task can be assigned. Defaults to 0.5. + group (str): Group name to initialize proxy. Defaults to DEFAULT_POLICY_GROUP. + proxy_kwargs (dict): Keyword arguments for proxy. Defaults to empty dict. + logger (Logger): Defaults to DummyLogger().
+ """ num_workers = len(worker_ids) if num_hosts == 0: # for multi-process mode diff --git a/maro/rl/utils/message_enums.py b/maro/rl/utils/message_enums.py index a4c317059..d1d57e0af 100644 --- a/maro/rl/utils/message_enums.py +++ b/maro/rl/utils/message_enums.py @@ -38,8 +38,10 @@ class MsgKey(Enum): ROLLOUT_INFO = "rollout_info" TRACKER = "tracker" GRAD_TASK = "grad_task" + GRAD_SCOPE = "grad_scope" LOSS_INFO = "loss_info" STATE = "state" + TENSOR = "tensor" POLICY_STATE = "policy_state" EXPLORATION_STEP = "exploration_step" VERSION = "version" diff --git a/maro/rl_v3/learning/rollout_manager.py b/maro/rl_v3/learning/rollout_manager.py index 0df4990ba..5fd46cbf4 100644 --- a/maro/rl_v3/learning/rollout_manager.py +++ b/maro/rl_v3/learning/rollout_manager.py @@ -12,6 +12,7 @@ from maro.communication import Proxy, SessionType from maro.rl.utils import MsgKey, MsgTag from maro.utils import DummyLogger, Logger, set_seeds + from .env_sampler import AbsEnvSampler, ExpElement diff --git a/maro/rl_v3/learning/trainer_manager.py b/maro/rl_v3/learning/trainer_manager.py index aa2efce93..071886eeb 100644 --- a/maro/rl_v3/learning/trainer_manager.py +++ b/maro/rl_v3/learning/trainer_manager.py @@ -7,6 +7,7 @@ from maro.rl_v3.policy import RLPolicy from maro.rl_v3.policy_trainer import AbsTrainer, MultiTrainer, SingleTrainer from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch + from .env_sampler import ExpElement @@ -164,7 +165,7 @@ def _dispatch_experience(self, exp_element: ExpElement) -> None: next_states=np.expand_dims(next_agent_state, axis=0), terminals=np.array([terminal]) ) - trainer.record(policy_name=policy_name, transition_batch=batch) + trainer.record(transition_batch=batch) elif isinstance(trainer, MultiTrainer): policy_names: List[str] = [] actions: List[np.ndarray] = [] diff --git a/maro/rl_v3/model/abs_net.py b/maro/rl_v3/model/abs_net.py index 5833c8d33..6ffa20a80 100644 --- a/maro/rl_v3/model/abs_net.py +++ b/maro/rl_v3/model/abs_net.py @@ -13,7 +13,6 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta): def __init__(self) -> None: super(AbsNet, self).__init__() - @abstractmethod def step(self, loss: torch.Tensor) -> None: """ Run a training step to update its own parameters according to the given loss. @@ -21,7 +20,7 @@ def step(self, loss: torch.Tensor) -> None: Args: loss (torch.tensor): Loss used to update the model. 
""" - raise NotImplementedError + self.apply_gradients(self.get_gradients(loss)) @abstractmethod def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: @@ -36,6 +35,10 @@ def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: """ raise NotImplementedError + @abstractmethod + def apply_gradients(self, grad: dict) -> None: + raise NotImplementedError + def _forward_unimplemented(self, *input: Any) -> None: # TODO pass diff --git a/maro/rl_v3/model/multi_q_net.py b/maro/rl_v3/model/multi_q_net.py index 487c740b7..ba2610198 100644 --- a/maro/rl_v3/model/multi_q_net.py +++ b/maro/rl_v3/model/multi_q_net.py @@ -4,6 +4,7 @@ import torch from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape + from .abs_net import AbsNet diff --git a/maro/rl_v3/model/policy_net.py b/maro/rl_v3/model/policy_net.py index 7f12ad994..ff11b6fec 100644 --- a/maro/rl_v3/model/policy_net.py +++ b/maro/rl_v3/model/policy_net.py @@ -1,10 +1,11 @@ from abc import ABCMeta, abstractmethod -from typing import Optional, Tuple +from typing import Optional import torch.nn from torch.distributions import Categorical from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape + from .abs_net import AbsNet diff --git a/maro/rl_v3/model/q_net.py b/maro/rl_v3/model/q_net.py index 1eafe33b9..612b85df8 100644 --- a/maro/rl_v3/model/q_net.py +++ b/maro/rl_v3/model/q_net.py @@ -4,6 +4,7 @@ import torch from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape + from .abs_net import AbsNet diff --git a/maro/rl_v3/model/v_net.py b/maro/rl_v3/model/v_net.py index c9dc42d4e..ce88a77ac 100644 --- a/maro/rl_v3/model/v_net.py +++ b/maro/rl_v3/model/v_net.py @@ -3,6 +3,7 @@ import torch from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape + from .abs_net import AbsNet diff --git a/maro/rl_v3/policy/abs_policy.py b/maro/rl_v3/policy/abs_policy.py index 2feb481b1..36e29ef7c 100644 --- a/maro/rl_v3/policy/abs_policy.py +++ b/maro/rl_v3/policy/abs_policy.py @@ -138,8 +138,9 @@ def step(self, loss: torch.Tensor) -> None: Args: loss (torch.Tensor): Loss used to update the policy. """ - raise NotImplementedError + self.apply_gradients(self.get_gradients(loss)) + @abstractmethod def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: """ Get the gradients with respect to all parameters of the internal nets according to the given loss. 
@@ -152,6 +153,10 @@ def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: """ raise NotImplementedError + @abstractmethod + def apply_gradients(self, grad: dict) -> None: + raise NotImplementedError + def get_actions(self, states: np.ndarray) -> np.ndarray: return self.get_actions_tensor(ndarray_to_tensor(states, self._device)).cpu().numpy() diff --git a/maro/rl_v3/policy/continuous_rl_policy.py b/maro/rl_v3/policy/continuous_rl_policy.py index f66447392..c147f7da6 100644 --- a/maro/rl_v3/policy/continuous_rl_policy.py +++ b/maro/rl_v3/policy/continuous_rl_policy.py @@ -4,6 +4,7 @@ import torch from maro.rl_v3.model import ContinuousPolicyNet + from .abs_policy import RLPolicy @@ -74,12 +75,12 @@ def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool: def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: return self._policy_net.get_actions(states, exploring) - def step(self, loss: torch.Tensor) -> None: - self._policy_net.step(loss) - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: return self._policy_net.get_gradients(loss) + def apply_gradients(self, grad: dict) -> None: + self._policy_net.apply_gradients(grad) + def freeze(self) -> None: self._policy_net.freeze() diff --git a/maro/rl_v3/policy/discrete_rl_policy.py b/maro/rl_v3/policy/discrete_rl_policy.py index d8c302b43..82c496928 100644 --- a/maro/rl_v3/policy/discrete_rl_policy.py +++ b/maro/rl_v3/policy/discrete_rl_policy.py @@ -8,6 +8,7 @@ from maro.rl_v3.model import DiscretePolicyNet, DiscreteQNet from maro.rl_v3.utils import match_shape, ndarray_to_tensor from maro.utils import clone + from .abs_policy import RLPolicy @@ -105,12 +106,12 @@ def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tens actions = ndarray_to_tensor(actions, self._device) return actions.unsqueeze(1) # [B, 1] - def step(self, loss: torch.Tensor) -> None: - self._q_net.step(loss) - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: return self._q_net.get_gradients(loss) + def apply_gradients(self, grad: dict) -> None: + self._q_net.apply_gradients(grad) + def freeze(self) -> None: self._q_net.freeze() @@ -163,12 +164,12 @@ def policy_net(self) -> DiscretePolicyNet: def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: return self._policy_net.get_actions(states, exploring) - def step(self, loss: torch.Tensor) -> None: - self._policy_net.step(loss) - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: return self._policy_net.get_gradients(loss) + def apply_gradients(self, grad: dict) -> None: + self._policy_net.apply_gradients(grad) + def freeze(self) -> None: self._policy_net.freeze() diff --git a/maro/rl_v3/policy_trainer/__init__.py b/maro/rl_v3/policy_trainer/__init__.py index 65690de80..32cebf188 100644 --- a/maro/rl_v3/policy_trainer/__init__.py +++ b/maro/rl_v3/policy_trainer/__init__.py @@ -1,16 +1,13 @@ from .abs_trainer import AbsTrainer, MultiTrainer, SingleTrainer from .ac import DiscreteActorCritic from .ddpg import DDPG -from .discrete_maddpg import DiscreteMADDPG +from .distributed_discrete_maddpg import DistributedDiscreteMADDPG from .dqn import DQN -from .maac import DiscreteMultiActorCritic - __all__ = [ "AbsTrainer", "MultiTrainer", "SingleTrainer", "DiscreteActorCritic", "DDPG", - "DQN", - "DiscreteMultiActorCritic", - "DiscreteMADDPG" + "DistributedDiscreteMADDPG", + "DQN" ] diff --git a/maro/rl_v3/policy_trainer/abs_train_ops.py 
b/maro/rl_v3/policy_trainer/abs_train_ops.py new file mode 100644 index 000000000..1fc202197 --- /dev/null +++ b/maro/rl_v3/policy_trainer/abs_train_ops.py @@ -0,0 +1,189 @@ +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Optional, Union + +import torch + +from maro.communication import Proxy +from maro.rl.data_parallelism import TaskQueueClient +from maro.rl.utils import average_grads +from maro.rl_v3.policy import RLPolicy +from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch + + +class AbsTrainOps(object, metaclass=ABCMeta): + """The basic component for training a policy, which mainly takes charge of gradient computation and policy updates. + In a trainer, each train ops hosts a policy, and the trainer hosts several train ops. On gradient workers, + the train ops is the atomic representation of a policy, used to perform parallel gradient computation. + """ + def __init__( + self, + name: str, + device: torch.device, + enable_data_parallelism: bool = False + ) -> None: + super(AbsTrainOps, self).__init__() + self._name = name + self._enable_data_parallelism = enable_data_parallelism + self._task_queue_client: Optional[TaskQueueClient] = None + self._device = device + + @property + def name(self) -> str: + return self._name + + def _get_batch_grad( + self, + batch: Union[TransitionBatch, MultiTransitionBatch], + tensor_dict: Dict[str, object] = None, + scope: str = "all" + ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]: + if self._enable_data_parallelism: + gradients = self._remote_learn(batch, tensor_dict, scope) + return average_grads(gradients) + else: + return self.get_batch_grad(batch, tensor_dict, scope) + + def _remote_learn( + self, + batch: Union[TransitionBatch, MultiTransitionBatch], + tensor_dict: Dict[str, object] = None, + scope: str = "all" + ) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]: + """Learn a batch of experience data from remote gradient workers. + The task queue client will first request available gradient workers from the task queue. If all workers are busy, + it will keep waiting until at least one worker is available. Then the task queue client submits the batch and state + to the assigned workers to compute gradients. + """ + assert self._task_queue_client is not None + worker_id_list = self._task_queue_client.request_workers() + batch_list = self._dispatch_batch(batch, len(worker_id_list)) + # TODO: implement _dispatch_tensor_dict + tensor_dict_list = self._dispatch_tensor_dict(tensor_dict, len(worker_id_list)) + ops_state = self.get_ops_state_dict() + ops_name = self.name + loss_info_by_name = self._task_queue_client.submit( + worker_id_list, batch_list, tensor_dict_list, ops_state, ops_name, scope) + return loss_info_by_name[ops_name] + + @abstractmethod + def get_batch_grad( + self, + batch: Union[TransitionBatch, MultiTransitionBatch], + tensor_dict: Dict[str, object] = None, + scope: str = "all" + ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]: + raise NotImplementedError + + @abstractmethod + def _dispatch_batch( + self, + batch: Union[TransitionBatch, MultiTransitionBatch], + num_ops: int + ) -> Union[List[TransitionBatch], List[MultiTransitionBatch]]: + """Split an experience data batch into several parts. + For on-policy algorithms, like PG, the batch is split into several complete trajectories.
+ For off-policy algorithms, like DQN, the batch is treated as independent data points and split evenly.""" + raise NotImplementedError + + @abstractmethod + def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_ops: int) -> List[Dict[str, object]]: + raise NotImplementedError + + def init_data_parallel(self, *args, **kwargs) -> None: + """ + Initialize a proxy in the ops for data-parallel training, + using the same arguments as `Proxy`. + """ + self._task_queue_client = TaskQueueClient() + self._task_queue_client.create_proxy(*args, **kwargs) + + def init_data_parallel_with_existing_proxy(self, proxy: Proxy) -> None: + """ + Initialize the proxy in the ops with an existing one, for data-parallel training. + """ + self._task_queue_client = TaskQueueClient() + self._task_queue_client.set_proxy(proxy) + + def exit_data_parallel(self) -> None: + if self._task_queue_client is not None: + self._task_queue_client.exit() + self._task_queue_client = None + + @abstractmethod + def get_ops_state_dict(self, scope: str = "all") -> dict: + """ + Returns: + A dict that contains the ops' state. + """ + raise NotImplementedError + + @abstractmethod + def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None: + """Set the ops' state.""" + raise NotImplementedError + + +class SingleTrainOps(AbsTrainOps, metaclass=ABCMeta): + def __init__( + self, + name: str, + device: torch.device, + enable_data_parallelism: bool = False + ) -> None: + super(SingleTrainOps, self).__init__(name, device, enable_data_parallelism) + self._batch: Optional[TransitionBatch] = None + self._policy: Optional[RLPolicy] = None + + def register_policy(self, policy: RLPolicy) -> None: + policy.to_device(self._device) + self._register_policy_impl(policy) + + @abstractmethod + def _register_policy_impl(self, policy: RLPolicy) -> None: + raise NotImplementedError + + def set_batch(self, batch: TransitionBatch) -> None: + self._batch = batch + + def get_policy_state(self) -> object: + return self._policy.get_policy_state() + + def set_policy_state(self, policy_state: object) -> None: + self._policy.set_policy_state(policy_state) + + +class MultiTrainOps(AbsTrainOps, metaclass=ABCMeta): + def __init__( + self, + name: str, + device: torch.device, + enable_data_parallelism: bool = False + ) -> None: + super(MultiTrainOps, self).__init__(name, device, enable_data_parallelism) + self._batch: Optional[MultiTransitionBatch] = None + self._policies: Dict[int, RLPolicy] = {} + self._indexes: List[int] = [] + + @property + def num_policies(self) -> int: + return len(self._policies) + + def register_policies(self, policy_dict: Dict[int, RLPolicy]) -> None: + self._indexes = list(policy_dict.keys()) + for policy in policy_dict.values(): + policy.to_device(self._device) + self._register_policies_impl(policy_dict) + + @abstractmethod + def _register_policies_impl(self, policy_dict: Dict[int, RLPolicy]) -> None: + raise NotImplementedError + + def set_batch(self, batch: MultiTransitionBatch) -> None: + self._batch = batch + + def get_policy_state_dict(self) -> dict: + return {i: policy.get_policy_state() for i, policy in self._policies.items()} + + def set_policy_state_dict(self, policy_state_dict: dict) -> None: + for i, policy in self._policies.items(): + policy.set_policy_state(policy_state_dict[i]) diff --git a/maro/rl_v3/policy_trainer/abs_trainer.py b/maro/rl_v3/policy_trainer/abs_trainer.py index 3e68b09bc..c336cd058 100644 --- a/maro/rl_v3/policy_trainer/abs_trainer.py +++
b/maro/rl_v3/policy_trainer/abs_trainer.py @@ -7,15 +7,33 @@ from maro.rl_v3.replay_memory import MultiReplayMemory, ReplayMemory from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch +from .abs_train_ops import SingleTrainOps + class AbsTrainer(object, metaclass=ABCMeta): + """Policy trainer used to train policies. A trainer maintains several train ops and + controls their training logic, while the train ops take charge of the actual policy updates. """ - Policy trainer used to train policies. - """ - def __init__(self, name: str, device: str = None) -> None: + def __init__( + self, + name: str, + device: str = None, + enable_data_parallelism: bool = False, + train_batch_size: int = 128 + ) -> None: + """ + Args: + name (str): Name of the trainer. + device (str): Device to store this trainer. If it is None, the device will be set to "cpu" if cuda is + unavailable and "cuda" otherwise. Defaults to None. + enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False. + train_batch_size (int): Train batch size. Defaults to 128. + """ self._name = name self._device = torch.device(device) if device is not None \ else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._enable_data_parallelism = enable_data_parallelism + self._train_batch_size = train_batch_size print(f"Creating trainer {self.__class__.__name__} {name} on device {self._device}") @@ -55,39 +73,35 @@ class SingleTrainer(AbsTrainer, metaclass=ABCMeta): """ Policy trainer that trains only one policy. """ - def __init__(self, name: str, device: str = None) -> None: - super(SingleTrainer, self).__init__(name, device) - self._policy: Optional[RLPolicy] = None - self._replay_memory = Optional[ReplayMemory] - - def record( + def __init__( self, - policy_name: str, # TODO: need this? - transition_batch: TransitionBatch + name: str, + device: str = None, + enable_data_parallelism: bool = False, + train_batch_size: int = 128 ) -> None: + super(SingleTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size) + self._policy_name: Optional[str] = None + self._replay_memory: Optional[ReplayMemory] = None + self._ops: Optional[SingleTrainOps] = None + + def record(self, transition_batch: TransitionBatch) -> None: """ Record the experiences collected by external modules. Args: - policy_name (str): The name of the policy that generates this batch. transition_batch (TransitionBatch): A TransitionBatch item that contains a batch of experiences. """ - self._record_impl( - policy_name=policy_name, - transition_batch=transition_batch - ) - - def _record_impl(self, policy_name: str, transition_batch: TransitionBatch) -> None: - """ - Implementation of `record`. - """ self._replay_memory.put(transition_batch) + def _get_batch(self, batch_size: int = None) -> TransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) + def register_policy(self, policy: RLPolicy) -> None: """ Register the policy and finish other related initializations.
""" - policy.to_device(self._device) + self._policy_name = policy.name self._register_policy_impl(policy) @abstractmethod @@ -95,57 +109,52 @@ def _register_policy_impl(self, policy: RLPolicy) -> None: raise NotImplementedError def get_policy_state_dict(self) -> Dict[str, object]: - return {self._policy.name: self._policy.get_policy_state()} + return {self._policy_name: self._ops.get_policy_state()} def set_policy_state_dict(self, policy_state_dict: Dict[str, object]) -> None: - assert len(policy_state_dict) == 1 and self._policy.name in policy_state_dict - self._policy.set_policy_state(policy_state_dict[self._policy.name]) + assert len(policy_state_dict) == 1 and self._policy_name in policy_state_dict + self._ops.set_policy_state(policy_state_dict[self._policy_name]) class MultiTrainer(AbsTrainer, metaclass=ABCMeta): """ Policy trainer that trains multiple policies. """ - def __init__(self, name: str, device: str = None) -> None: - super(MultiTrainer, self).__init__(name, device) - self._policy_dict: Dict[str, RLPolicy] = {} - self._policies: List[RLPolicy] = [] + + def __init__( + self, + name: str, + device: str = None, + enable_data_parallelism: bool = False, + train_batch_size: int = 128 + ) -> None: + super(MultiTrainer, self).__init__(name, device, enable_data_parallelism, train_batch_size) + self._policy_names: List[str] = [] self._replay_memory: Optional[MultiReplayMemory] = None @property - def num_policies(self): - return len(self._policies) + def num_policies(self) -> int: + return len(self._policy_names) - def record( - self, - transition_batch: MultiTransitionBatch - ) -> None: + def record(self, transition_batch: MultiTransitionBatch) -> None: """ Record the experiences collected by external modules. Args: transition_batch (MultiTransitionBatch): A TransitionBatch item that contains a batch of experiences. """ - self._record_impl(transition_batch) + self._replay_memory.put(transition_batch) - @abstractmethod - def _record_impl(self, transition_batch: MultiTransitionBatch) -> None: - raise NotImplementedError + def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) def register_policies(self, policies: List[RLPolicy]) -> None: - for policy in policies: - policy.to_device(self._device) + """ + Register the policies and finish other related initializations. 
+ """ + self._policy_names = [policy.name for policy in policies] self._register_policies_impl(policies) @abstractmethod def _register_policies_impl(self, policies: List[RLPolicy]) -> None: - pass - - def get_policy_state_dict(self) -> Dict[str, object]: - return {policy_name: policy.get_policy_state() for policy_name, policy in self._policy_dict.items()} - - def set_policy_state_dict(self, policy_state_dict: Dict[str, object]) -> None: - assert len(policy_state_dict) == len(self._policy_dict) - for policy_name, policy_state in policy_state_dict.items(): - assert policy_name in self._policy_dict - self._policy_dict[policy_name].set_policy_state(policy_state) + raise NotImplementedError diff --git a/maro/rl_v3/policy_trainer/ac.py b/maro/rl_v3/policy_trainer/ac.py index b6d069877..270c44e96 100644 --- a/maro/rl_v3/policy_trainer/ac.py +++ b/maro/rl_v3/policy_trainer/ac.py @@ -1,20 +1,173 @@ -from typing import Callable, Optional +from typing import Callable, Dict, List import numpy as np import torch from maro.rl.utils import discount_cumsum from maro.rl_v3.model import VNet -from maro.rl_v3.policy import DiscretePolicyGradient +from maro.rl_v3.policy import DiscretePolicyGradient, RLPolicy from maro.rl_v3.replay_memory import FIFOReplayMemory from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor -from maro.utils import clone + +from .abs_train_ops import SingleTrainOps from .abs_trainer import SingleTrainer +class DiscreteActorCriticTrainOps(SingleTrainOps): + def __init__( + self, + name: str, + device: torch.device, + get_v_critic_net_func: Callable[[], VNet], + reward_discount: float = 0.9, + critic_loss_coef: float = 0.1, + critic_loss_cls: Callable = None, + clip_ratio: float = None, + lam: float = 0.9, + min_logp: float = None, + enable_data_parallelism: bool = False + ) -> None: + super(DiscreteActorCriticTrainOps, self).__init__(name, device, enable_data_parallelism) + + self._get_v_critic_net_func = get_v_critic_net_func + self._reward_discount = reward_discount + self._critic_loss_coef = critic_loss_coef + self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() + self._clip_ratio = clip_ratio + self._lam = lam + self._min_logp = min_logp + + def _register_policy_impl(self, policy: RLPolicy) -> None: + assert isinstance(policy, DiscretePolicyGradient) + + self._policy = policy + self._v_critic_net = self._get_v_critic_net_func() + self._v_critic_net.to(self._device) + + def get_batch_grad( + self, + batch: TransitionBatch, + tensor_dict: Dict[str, object] = None, + scope: str = "all" + ) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Reference: https://tinyurl.com/2ezte4cr + """ + assert scope in ("all", "actor", "critic"), \ + f"Unrecognized scope {scope}. Excepting 'all', 'actor', or 'critic'." 
+ + grad_dict = {} + if scope in ("all", "actor"): + grad_dict["actor_grad"] = self._get_actor_grad(batch) + + if scope in ("all", "critic"): + grad_dict["critic_grad"] = self._get_critic_grad(batch) + + return grad_dict + + def _dispatch_batch(self, batch: TransitionBatch, num_ops: int) -> List[TransitionBatch]: + raise NotImplementedError + + def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_ops: int) -> List[Dict[str, object]]: + raise NotImplementedError + + def _get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + states = ndarray_to_tensor(batch.states, self._device) # s + + self._policy.train() + self._v_critic_net.train() + + state_values = self._v_critic_net.v_values(states) + values = state_values.detach().numpy() + values = np.concatenate([values, values[-1:]]) + rewards = np.concatenate([batch.rewards, values[-1:]]) + + returns = ndarray_to_tensor(discount_cumsum(rewards, self._reward_discount)[:-1], self._device) + critic_loss = self._critic_loss_func(state_values, returns) + + return self._v_critic_net.get_gradients(critic_loss * self._critic_loss_coef) + + def _get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + states = ndarray_to_tensor(batch.states, self._device) # s + actions = ndarray_to_tensor(batch.actions, self._device).long() # a + + if self._clip_ratio is not None: + self._policy.eval() + logps_old = self._policy.get_state_action_logps(states, actions) + else: + logps_old = None + + state_values = self._v_critic_net.v_values(states) + values = state_values.detach().numpy() + values = np.concatenate([values, values[-1:]]) + rewards = np.concatenate([batch.rewards, values[-1:]]) + + deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s) + advantages = ndarray_to_tensor(discount_cumsum(deltas, self._reward_discount * self._lam), self._device) + + action_probs = self._policy.get_action_probs(states) + logps = torch.log(action_probs.gather(1, actions).squeeze()) + logps = torch.clamp(logps, min=self._min_logp, max=.0) + if self._clip_ratio is not None: + ratio = torch.exp(logps - logps_old) + clipped_ratio = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) + actor_loss = -(torch.min(ratio * advantages, clipped_ratio * advantages)).mean() + else: + actor_loss = -(logps * advantages).mean() # I * delta * log pi(a|s) + + return self._policy.get_gradients(actor_loss) + + def update(self) -> None: + """ + Reference: https://tinyurl.com/2ezte4cr + """ + grad_dict = self._get_batch_grad(self._batch, scope="all") + self._policy.train() + self._policy.apply_gradients(grad_dict["actor_grad"]) + self._v_critic_net.train() + self._v_critic_net.apply_gradients(grad_dict["critic_grad"]) + + def get_ops_state_dict(self, scope: str = "all") -> dict: + ret_dict = {} + if scope in ("all", "actor"): + ret_dict["policy_state"] = self._policy.get_policy_state() + if scope in ("all", "critic"): + ret_dict["critic_state"] = self._v_critic_net.get_net_state() + return ret_dict + + def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None: + if scope in ("all", "actor"): + self._policy.set_policy_state(ops_state_dict["policy_state"]) + if scope in ("all", "critic"): + self._v_critic_net.set_net_state(ops_state_dict["critic_state"]) + + class DiscreteActorCritic(SingleTrainer): - """ - TODO: docs. + """Actor Critic algorithm with separate policy and value models. + + References: + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch. 
+ https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f + + Args: + name (str): Unique identifier for the trainer. + get_v_critic_net_func (Callable[[], VNet]): Function to get V critic net. + policy (DiscretePolicyGradient): The policy to be trained. + replay_memory_capacity (int): Capacity of the replay memory. Defaults to 10000. + train_batch_size (int): Batch size for training. Defaults to 128. + grad_iters (int): Number of iterations to calculate gradients. Defaults to 1. + reward_discount (float): Reward decay as defined in standard RL terminology. Defaults to 0.9. + lam (float): Lambda value for generalized advantage estimation (TD-Lambda). Defaults to 0.9. + clip_ratio (float): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). Defaults to None, + in which case the actor loss is calculated using the usual policy gradient theorem. + critic_loss_cls (Callable): Loss class for computing the critic loss, e.g., one provided by torch.nn. + Defaults to None, in which case torch.nn.MSELoss is used. + min_logp (float): Lower bound for clamping logP values during learning. This is to prevent logP from becoming + very large in magnitude and causing stability issues. Defaults to None, which means no lower bound. + critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1. + device (str): Identifier for the torch device. The policy will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. + enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False. """ def __init__( self, @@ -30,15 +183,19 @@ def __init__( critic_loss_cls: Callable = None, min_logp: float = None, critic_loss_coef: float = 0.1, - device: str = None + device: str = None, + enable_data_parallelism: bool = False ) -> None: - super(DiscreteActorCritic, self).__init__(name, device) + super(DiscreteActorCritic, self).__init__( + name=name, + device=device, + enable_data_parallelism=enable_data_parallelism, + train_batch_size=train_batch_size + ) self._replay_memory_capacity = replay_memory_capacity - self._get_v_net_func = get_v_critic_net_func - self._policy: Optional[DiscretePolicyGradient] = None - self._v_critic_net: Optional[VNet] = None + self._get_v_critic_net_func = get_v_critic_net_func if policy is not None: self.register_policy(policy) @@ -49,66 +206,23 @@ def __init__( self._min_logp = min_logp self._grad_iters = grad_iters self._critic_loss_coef = critic_loss_coef - - self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() - - def _record_impl(self, policy_name: str, transition_batch: TransitionBatch) -> None: - self._replay_memory.put(transition_batch) - - def _get_batch(self, batch_size: int = None) -> TransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) + self._critic_loss_cls = critic_loss_cls def _register_policy_impl(self, policy: DiscretePolicyGradient) -> None: - assert isinstance(policy, DiscretePolicyGradient) - self._policy = policy + self._ops = DiscreteActorCriticTrainOps( + name="ops", device=self._device, get_v_critic_net_func=self._get_v_critic_net_func, + reward_discount=self._reward_discount, critic_loss_coef=self._critic_loss_coef, + critic_loss_cls=self._critic_loss_cls, clip_ratio=self._clip_ratio, lam=self._lam, + min_logp=self._min_logp,
enable_data_parallelism=self._enable_data_parallelism + ) + self._ops.register_policy(policy) + self._replay_memory = FIFOReplayMemory( capacity=self._replay_memory_capacity, state_dim=policy.state_dim, action_dim=policy.action_dim ) - self._v_critic_net = self._get_v_net_func() - self._v_critic_net.to(self._device) def train_step(self) -> None: - self._improve(self._get_batch()) - - def _improve(self, batch: TransitionBatch) -> None: - """ - Reference: https://tinyurl.com/2ezte4cr - """ - v_critic_net_copy = clone(self._v_critic_net) - v_critic_net_copy.eval() - - states = ndarray_to_tensor(batch.states, self._device) # s - actions = ndarray_to_tensor(batch.actions, self._device).long() # a - - self._policy.eval() - logps_old = self._policy.get_state_action_logps(states, actions) # log pi(a|s), action log-prob when sampling - - self._policy.train() - self._v_critic_net.train() + self._ops.set_batch(self._get_batch()) for _ in range(self._grad_iters): - state_values = self._v_critic_net.v_values(states) - values = state_values.detach().numpy() - values = np.concatenate([values, values[-1:]]) - rewards = np.concatenate([batch.rewards, values[-1:]]) - deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s) - returns = ndarray_to_tensor(discount_cumsum(rewards, self._reward_discount)[:-1], self._device) - advantages = ndarray_to_tensor(discount_cumsum(deltas, self._reward_discount * self._lam), self._device) - - # Critic loss - critic_loss = self._critic_loss_func(state_values, returns) - - # Actor loss - action_probs = self._policy.get_action_probs(states) - logps = torch.log(action_probs.gather(1, actions).squeeze()) - logps = torch.clamp(logps, min=self._min_logp, max=.0) - if self._clip_ratio is not None: - ratio = torch.exp(logps - logps_old) - clipped_ratio = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) - actor_loss = -(torch.min(ratio * advantages, clipped_ratio * advantages)).mean() - else: - actor_loss = -(logps * advantages).mean() # I * delta * log pi(a|s) - - # Update - self._policy.step(actor_loss) - self._v_critic_net.step(critic_loss * self._critic_loss_coef) + self._ops.update() diff --git a/maro/rl_v3/policy_trainer/ddpg.py b/maro/rl_v3/policy_trainer/ddpg.py index db58a878b..96f0a9963 100644 --- a/maro/rl_v3/policy_trainer/ddpg.py +++ b/maro/rl_v3/policy_trainer/ddpg.py @@ -1,89 +1,104 @@ -from typing import Callable, Optional +from typing import Callable, Dict, List, Optional import torch from maro.rl_v3.model import QNet -from maro.rl_v3.policy import ContinuousRLPolicy +from maro.rl_v3.policy import ContinuousRLPolicy, RLPolicy from maro.rl_v3.replay_memory import RandomReplayMemory from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor from maro.utils import clone + +from .abs_train_ops import SingleTrainOps from .abs_trainer import SingleTrainer -class DDPG(SingleTrainer): +class DDPGTrainOps(SingleTrainOps): def __init__( self, name: str, + device: torch.device, get_q_critic_net_func: Callable[[], QNet], reward_discount: float, q_value_loss_cls: Callable = None, - policy: ContinuousRLPolicy = None, - random_overwrite: bool = False, - replay_memory_capacity: int = 1000000, - num_epochs: int = 1, - update_target_every: int = 5, soft_update_coef: float = 1.0, - train_batch_size: int = 32, critic_loss_coef: float = 0.1, - device: str = None + enable_data_parallelism: bool = False ) -> None: - super(DDPG, self).__init__(name=name, device=device) + super(DDPGTrainOps, self).__init__(name, device, 
enable_data_parallelism) - self._policy: ContinuousRLPolicy = Optional[ContinuousRLPolicy] - self._target_policy: ContinuousRLPolicy = Optional[ContinuousRLPolicy] - self._q_critic_net: QNet = Optional[QNet] - self._target_q_critic_net: QNet = Optional[QNet] + self._policy: Optional[ContinuousRLPolicy] = None + self._target_policy: Optional[ContinuousRLPolicy] = None + self._q_critic_net: Optional[QNet] = None + self._target_q_critic_net: Optional[QNet] = None self._get_q_critic_net_func = get_q_critic_net_func - self._replay_memory_capacity = replay_memory_capacity - self._random_overwrite = random_overwrite - if policy is not None: - self.register_policy(policy) - self._num_epochs = num_epochs - self._policy_ver = self._target_policy_ver = 0 - self._update_target_every = update_target_every - self._soft_update_coef = soft_update_coef - self._train_batch_size = train_batch_size self._reward_discount = reward_discount self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() self._critic_loss_coef = critic_loss_coef + self._soft_update_coef = soft_update_coef - def _record_impl(self, policy_name: str, transition_batch: TransitionBatch) -> None: - self._replay_memory.put(transition_batch) - - def _register_policy_impl(self, policy: ContinuousRLPolicy) -> None: + def _register_policy_impl(self, policy: RLPolicy) -> None: assert isinstance(policy, ContinuousRLPolicy) + self._policy = policy self._target_policy = clone(self._policy) self._target_policy.set_name(f"target_{policy.name}") self._target_policy.eval() - self._replay_memory = RandomReplayMemory( - capacity=self._replay_memory_capacity, state_dim=policy.state_dim, - action_dim=policy.action_dim, random_overwrite=self._random_overwrite - ) + self._target_policy.to_device(self._device) + self._q_critic_net = self._get_q_critic_net_func() + self._q_critic_net.to(self._device) self._target_q_critic_net: QNet = clone(self._q_critic_net) self._target_q_critic_net.eval() - - self._target_policy.to_device(self._device) - self._q_critic_net.to(self._device) self._target_q_critic_net.to(self._device) - def _get_batch(self, batch_size: int = None) -> TransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) - - def train_step(self) -> None: - for _ in range(self._num_epochs): - self._improve(self._get_batch()) - self._update_target_policy() - - def _improve(self, batch: TransitionBatch) -> None: + def get_batch_grad( + self, + batch: TransitionBatch, + tensor_dict: Dict[str, object] = None, + scope: str = "all" + ) -> Dict[str, Dict[str, torch.Tensor]]: """ Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html """ + + assert scope in ("all", "actor", "critic"), \ + f"Unrecognized scope {scope}. Excepting 'all', 'actor', or 'critic'." 
+ + grad_dict = {} + if scope in ("all", "critic"): + grad_dict["critic_grad"] = self._get_critic_grad(batch) + + if scope in ("all", "actor"): + grad_dict["actor_grad"] = self._get_actor_grad(batch) + + return grad_dict + + def _dispatch_batch(self, batch: TransitionBatch, num_ops: int) -> List[TransitionBatch]: + raise NotImplementedError + + def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_ops: int) -> List[Dict[str, object]]: + raise NotImplementedError + + def _get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + self._q_critic_net.train() + self._policy.train() + states = ndarray_to_tensor(batch.states, self._device) # s + + policy_loss = -self._q_critic_net.q_values( + states=states, # s + actions=self._policy.get_actions_tensor(states) # miu(s) + ).mean() # -Q(s, miu(s)) + + return self._policy.get_gradients(policy_loss) + + def _get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + self._q_critic_net.train() + self._policy.train() + + states = ndarray_to_tensor(batch.states, self._device) # s + next_states = ndarray_to_tensor(batch.next_states, self._device) # s' actions = ndarray_to_tensor(batch.actions, self._device) # a rewards = ndarray_to_tensor(batch.rewards, self._device) # r @@ -100,20 +115,133 @@ def _improve(self, batch: TransitionBatch) -> None: q_values = self._q_critic_net.q_values(states=states, actions=actions) # Q(s, a) critic_loss = self._q_value_loss_func(q_values, target_q_values) # MSE(Q(s, a), y(r, s', d)) - policy_loss = -self._q_critic_net.q_values( - states=states, # s - actions=self._policy.get_actions_tensor(states) # miu(s) - ).mean() # -Q(s, miu(s)) - # Update Q first, then freeze Q and update miu. - self._q_critic_net.step(critic_loss * self._critic_loss_coef) - self._q_critic_net.freeze() - self._policy.step(policy_loss) - self._q_critic_net.unfreeze() - - def _update_target_policy(self) -> None: - self._policy_ver += 1 - if self._policy_ver - self._target_policy_ver == self._update_target_every: - self._target_policy.soft_update(self._policy, self._soft_update_coef) - self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) - self._target_policy_ver = self._policy_ver + return self._q_critic_net.get_gradients(critic_loss * self._critic_loss_coef) + + def update(self) -> None: + grad_dict = self._get_batch_grad(self._batch, scope="critic") + self._q_critic_net.train() + self._q_critic_net.apply_gradients(grad_dict["critic_grad"]) + + grad_dict = self._get_batch_grad(self._batch, scope="actor") + self._policy.train() + self._policy.apply_gradients(grad_dict["actor_grad"]) + + def get_ops_state_dict(self, scope: str = "all") -> dict: + ret_dict = {} + if scope in ("all", "actor"): + ret_dict["policy_state"] = self._policy.get_policy_state() + ret_dict["target_policy_state"] = self._target_policy.get_policy_state() + if scope in ("all", "critic"): + ret_dict["critic_state"] = self._q_critic_net.get_net_state() + ret_dict["target_critic_state"] = self._target_q_critic_net.get_net_state() + return ret_dict + + def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None: + if scope in ("all", "actor"): + self._policy.set_policy_state(ops_state_dict["policy_state"]) + self._target_policy.set_policy_state(ops_state_dict["target_policy_state"]) + if scope in ("all", "critic"): + self._q_critic_net.set_net_state(ops_state_dict["critic_state"]) + self._target_q_critic_net.set_net_state(ops_state_dict["target_critic_state"]) + + def soft_update_target(self) ->
None: + self._target_policy.soft_update(self._policy, self._soft_update_coef) + self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) + + +class DDPG(SingleTrainer): + """The Deep Deterministic Policy Gradient (DDPG) algorithm. + + References: + https://arxiv.org/pdf/1509.02971.pdf + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg + + Args: + name (str): Unique identifier for the trainer. + get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net. + reward_discount (float): Reward decay as defined in standard RL terminology. + q_value_loss_cls (Callable): Loss class for computing the Q-value loss, e.g., one provided by torch.nn. + Defaults to None, in which case torch.nn.MSELoss is used. + policy (ContinuousRLPolicy): The policy to be trained. + random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True, + overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with + wrap-around. Defaults to False. + replay_memory_capacity (int): Capacity of the replay memory. Defaults to 10000. + num_epochs (int): Number of training epochs per call to ``train_step``. Defaults to 1. + update_target_every (int): Number of training rounds between policy target model updates. Defaults to 5. + soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model + + (1-soft_update_coef) * target_model. Defaults to 1.0. + train_batch_size (int): Batch size for training the Q-net. Defaults to 32. + critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1. + device (str): Identifier for the torch device. The policy will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. + enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
+ """ + + def __init__( + self, + name: str, + get_q_critic_net_func: Callable[[], QNet], + reward_discount: float, + q_value_loss_cls: Callable = None, + policy: ContinuousRLPolicy = None, + random_overwrite: bool = False, + replay_memory_capacity: int = 10000, + num_epochs: int = 1, + update_target_every: int = 5, + soft_update_coef: float = 1.0, + train_batch_size: int = 32, + critic_loss_coef: float = 0.1, + device: str = None, + enable_data_parallelism: bool = False + ) -> None: + super(DDPG, self).__init__( + name=name, + device=device, + enable_data_parallelism=enable_data_parallelism, + train_batch_size=train_batch_size + ) + + self._get_q_critic_net_func = get_q_critic_net_func + + self._replay_memory_capacity = replay_memory_capacity + self._random_overwrite = random_overwrite + if policy is not None: + self.register_policy(policy) + + self._num_epochs = num_epochs + self._policy_version = self._target_policy_version = 0 + self._update_target_every = update_target_every + self._soft_update_coef = soft_update_coef + self._train_batch_size = train_batch_size + self._reward_discount = reward_discount + + self._critic_loss_coef = critic_loss_coef + self._q_value_loss_cls = q_value_loss_cls + + def _register_policy_impl(self, policy: ContinuousRLPolicy) -> None: + self._ops = DDPGTrainOps( + name="ops", device=self._device, get_q_critic_net_func=self._get_q_critic_net_func, + reward_discount=self._reward_discount, q_value_loss_cls=self._q_value_loss_cls, + soft_update_coef=self._soft_update_coef, critic_loss_coef=self._critic_loss_coef, + enable_data_parallelism=self._enable_data_parallelism + ) + self._ops.register_policy(policy) + + self._replay_memory = RandomReplayMemory( + capacity=self._replay_memory_capacity, state_dim=policy.state_dim, + action_dim=policy.action_dim, random_overwrite=self._random_overwrite + ) + + def train_step(self) -> None: + for _ in range(self._num_epochs): + self._ops.set_batch(self._get_batch()) + self._ops.update() + self._try_soft_update_target() + + def _try_soft_update_target(self) -> None: + self._policy_version += 1 + if self._policy_version - self._target_policy_version == self._update_target_every: + self._ops.soft_update_target() + self._target_policy_version = self._policy_version diff --git a/maro/rl_v3/policy_trainer/discrete_maddpg.py b/maro/rl_v3/policy_trainer/discrete_maddpg.py deleted file mode 100644 index fb34ac47b..000000000 --- a/maro/rl_v3/policy_trainer/discrete_maddpg.py +++ /dev/null @@ -1,224 +0,0 @@ -from typing import Callable, Dict, List, Optional - -import numpy as np -import torch - -from maro.rl_v3.model import MultiQNet -from maro.rl_v3.policy import DiscretePolicyGradient, RLPolicy -from maro.rl_v3.replay_memory import RandomMultiReplayMemory -from maro.rl_v3.utils import MultiTransitionBatch, ndarray_to_tensor -from maro.utils import clone - -from .abs_trainer import MultiTrainer - - -class DiscreteMADDPG(MultiTrainer): - def __init__( - self, - name: str, - reward_discount: float, - get_q_critic_net_func: Callable[[], MultiQNet], - policies: List[RLPolicy] = None, - replay_memory_capacity: int = 10000, - num_epoch: int = 10, - update_target_every: int = 5, - soft_update_coef: float = 0.5, - train_batch_size: int = 32, - q_value_loss_cls: Callable = None, - device: str = None, - critic_loss_coef: float = 1.0, - shared_critic: bool = False - - ) -> None: - super(DiscreteMADDPG, self).__init__(name, device) - - self._get_q_critic_net_func = get_q_critic_net_func - self._q_critic_nets: Optional[List[MultiQNet]] = None - 
self._target_q_critic_nets: Optional[List[MultiQNet]] = None - self._replay_memory_capacity = replay_memory_capacity - self._target_policies: Optional[List[DiscretePolicyGradient]] = None - if policies is not None: - self.register_policies(policies) - - self._num_epoch = num_epoch - self._update_target_every = update_target_every - self._policy_version = self._target_policy_version = 0 - self._soft_update_coef = soft_update_coef - self._train_batch_size = train_batch_size - self._reward_discount = reward_discount - self._critic_loss_coef = critic_loss_coef - self._shared_critic = shared_critic - - self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() - - def _record_impl(self, transition_batch: MultiTransitionBatch) -> None: - self._replay_memory.put(transition_batch) - - def _register_policies_impl(self, policies: List[RLPolicy]) -> None: - assert all(isinstance(policy, DiscretePolicyGradient) for policy in policies) - - self._policies = policies - self._policy_dict = { - policy.name: policy for policy in policies - } - if self._shared_critic: - q_critic_net = self._get_q_critic_net_func() - q_critic_net.to(self._device) - self._q_critic_nets = [q_critic_net for _ in range(self.num_policies)] - else: - self._q_critic_nets = [self._get_q_critic_net_func().to(self._device) for i in range(self.num_policies)] - - self._replay_memory = RandomMultiReplayMemory( - capacity=self._replay_memory_capacity, - state_dim=self._q_critic_nets[0].state_dim, - action_dims=[policy.action_dim for policy in policies], - agent_states_dims=[policy.state_dim for policy in policies] - ) - - self._target_policies: List[DiscretePolicyGradient] = [] - for policy in self._policies: - target_policy = clone(policy) - target_policy.set_name(f"target_{policy.name}") - self._target_policies.append(target_policy) - - for policy in self._target_policies: - policy.eval() - - self._target_q_critic_nets = [clone(net) for net in self._q_critic_nets] - for i in range(self.num_policies): - self._target_q_critic_nets[i].eval() - self._target_q_critic_nets[i].to(self._device) - - for policy in self._target_policies: - policy.to_device(self._device) - - def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) - - def train_step(self) -> None: - for _ in range(self._num_epoch): - train_batch = self._get_batch() - # iteratively update critic & actors - loss = self.get_batch_loss(train_batch, scope="critic") - for critic_net, critic_loss in zip(self._q_critic_nets, loss["critic_losses"]): - critic_net.step(critic_loss) - if self._shared_critic: - break # only update once for shared critic - - loss = self.get_batch_loss(train_batch, scope="actor") - for policy, actor_loss in zip(self._policies, loss["actor_losses"]): - policy.step(actor_loss) - - self._update_target_policy() - - def get_batch_loss(self, batch: MultiTransitionBatch, scope="all") -> Dict[str, List[torch.Tensor]]: - """Get loss with a batch of data. If scope is specified, return the expected loss of scope. - - Args: - batch (MultiTransitionBatch): The batch of multi-agent experience data. - scope (str): The expected scope to compute loss. Should be in ['all', 'critic', 'actor']. - - Returns: - loss_info (Dict[str, List[torch.Tensor]]): Loss of each scope. - """ - assert scope in ["all", "critic", "actor"], f'scope should in ["all", "critic", "actor"] but get {scope}.' 
- loss_info = dict() - if scope == "all" or scope == "critic": - critic_losses = self._get_critic_losses(batch) - loss_info["critic_losses"] = critic_losses - if scope == "all" or scope == "actor": - actor_losses = self._get_actor_losses(batch) - loss_info["actor_losses"] = actor_losses - return loss_info - - def _get_critic_losses(self, batch: MultiTransitionBatch) -> List[torch.Tensor]: - states = ndarray_to_tensor(batch.states, self._device) # x - next_states = ndarray_to_tensor(batch.next_states, self._device) # x' - agent_states = [ndarray_to_tensor(agent_state, self._device) for agent_state in batch.agent_states] # o - actions = [ndarray_to_tensor(action, self._device) for action in batch.actions] # a - rewards = ndarray_to_tensor(np.vstack([reward for reward in batch.rewards]), self._device) # r - terminals = ndarray_to_tensor(batch.terminals, self._device) # d - - with torch.no_grad(): - next_actions = [ - policy.get_actions_tensor(agent_state) # a' = miu'(o) - for policy, agent_state in zip(self._target_policies, agent_states) - ] - - critic_losses = [] - if self._shared_critic: - with torch.no_grad(): - next_q_values = self._target_q_critic_nets[0].q_values( - states=next_states, # x' - actions=next_actions # a' - ) # Q'(x', a') - # sum(rewards) for shard critic - target_q_values = (rewards.sum(0) + self._reward_discount * (1 - terminals.float()) * next_q_values) - q_values = self._q_critic_nets[0].q_values( - states=states, # x - actions=actions # a - ) # Q(x, a) - critic_loss = self._q_value_loss_func(q_values, target_q_values.detach()) * self._critic_loss_coef - critic_losses.append(critic_loss) - else: - for i in range(self.num_policies): - with torch.no_grad(): - next_q_values = self._target_q_critic_nets[i].q_values( - states=next_states, # x' - actions=next_actions) # a' - target_q_values = ( - rewards[i] + self._reward_discount * (1 - terminals.float()) * next_q_values) - q_values = self._q_critic_nets[i].q_values( - states=states, # x - actions=actions # a - ) # Q(x, a) - critic_loss = self._q_value_loss_func(q_values, target_q_values.detach()) * self._critic_loss_coef - critic_losses.append(critic_loss) - - return critic_losses - - def _get_actor_losses(self, batch: MultiTransitionBatch) -> List[torch.Tensor]: - for policy in self._policies: - policy.train() - - states = ndarray_to_tensor(batch.states, self._device) # x - agent_states = [ndarray_to_tensor(agent_state, self._device) for agent_state in batch.agent_states] # o - actions = [ndarray_to_tensor(action, self._device) for action in batch.actions] # a - - latest_actions = [] - latest_action_logps = [] - for policy, agent_state in zip(self._policies, agent_states): - assert isinstance(policy, DiscretePolicyGradient) - latest_actions.append(policy.get_actions_tensor(agent_state)) # a = miu(o) - latest_action_logps.append(policy.get_state_action_logps( - agent_state, # o - latest_actions[-1] # a - )) # log pi(a|o) - - actor_losses = [] - for i in range(self.num_policies): - # Update actor - self._q_critic_nets[i].freeze() - - action_backup = actions[i] - actions[i] = latest_actions[i] # Replace latest action - actor_loss = -(self._q_critic_nets[i].q_values( - states=states, # x - actions=actions # [a^j_1, ..., a_i, ..., a^j_N] - ) * latest_action_logps[i]).mean() # Q(x, a^j_1, ..., a_i, ..., a^j_N) - actor_losses.append(actor_loss) - - actions[i] = action_backup # Restore original action - self._q_critic_nets[i].unfreeze() - return actor_losses - - def _update_target_policy(self) -> None: - self._policy_version += 1 
-        if self._policy_version - self._target_policy_version == self._update_target_every:
-            for policy, target_policy in zip(self._policies, self._target_policies):
-                target_policy.soft_update(policy, self._soft_update_coef)
-            for critic, target_critic in zip(self._q_critic_nets, self._target_q_critic_nets):
-                target_critic.soft_update(critic, self._soft_update_coef)
-                if self._shared_critic:
-                    break  # only update once for shared critic
-            self._target_policy_version = self._policy_version
diff --git a/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
new file mode 100644
index 000000000..8f2f7928c
--- /dev/null
+++ b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
@@ -0,0 +1,455 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from maro.rl_v3.model import MultiQNet
+from maro.rl_v3.policy import DiscretePolicyGradient, RLPolicy
+from maro.rl_v3.policy_trainer import MultiTrainer
+from maro.rl_v3.policy_trainer.abs_train_ops import MultiTrainOps
+from maro.rl_v3.replay_memory import RandomMultiReplayMemory
+from maro.rl_v3.utils import MultiTransitionBatch, ndarray_to_tensor
+from maro.utils import clone
+
+
+class DiscreteMADDPGTrainOps(MultiTrainOps):
+    """The discrete variant of the MADDPG algorithm.
+    Args:
+        name (str): Name of the worker.
+        device (torch.device): Which device to use.
+        reward_discount (float): The discount factor for future rewards.
+        get_q_critic_net_func (Callable[[], MultiQNet]): Function to get the Q critic net.
+        shared_critic (bool): Whether to share the critic among actors. Defaults to False.
+        critic_loss_coef (float): Coefficient of the critic loss in the total loss. Defaults to 1.0.
+        soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
+            (1-soft_update_coef) * target_model. Defaults to 0.5.
+        update_target_every (int): Number of training rounds between policy target model updates. Defaults to 5.
+        q_value_loss_func (Callable): The loss function provided by torch.nn or a custom loss class for the
+            Q-value loss. Defaults to None.
+        enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
+ + Reference: + Paper: http://papers.nips.cc/paper/by-source-2017-3193 + Code: https://github.com/openai/maddpg + """ + def __init__( + self, + name: str, + device: torch.device, + reward_discount: float, + get_q_critic_net_func: Callable[[], MultiQNet], + shared_critic: bool = False, + critic_loss_coef: float = 1.0, + soft_update_coef: float = 0.5, + update_target_every: int = 5, + q_value_loss_func: Callable = None, + enable_data_parallelism: bool = False + ) -> None: + super(DiscreteMADDPGTrainOps, self).__init__(name, device, enable_data_parallelism) + + # Actor + self._target_policies: Dict[int, DiscretePolicyGradient] = {} + + # Critic + self._get_q_critic_net_func = get_q_critic_net_func + self._q_critic_nets: Dict[int, MultiQNet] = {} + self._target_q_critic_nets: Dict[int, MultiQNet] = {} + + # + self._shared_critic = shared_critic + + self._reward_discount = reward_discount + self._critic_loss_coef = critic_loss_coef + self._q_value_loss_func = q_value_loss_func + self._update_target_every = update_target_every + self._soft_update_coef = soft_update_coef + + def _register_policies_impl(self, policy_dict: Dict[int, RLPolicy]) -> None: + # Actors + self._policies: Dict[int, DiscretePolicyGradient] = {} + self._target_policies: Dict[int, DiscretePolicyGradient] = {} + for i, policy in policy_dict.items(): + assert isinstance(policy, DiscretePolicyGradient) + target_policy: DiscretePolicyGradient = clone(policy) + target_policy.set_name(f"target_{policy.name}") + target_policy.to_device(self._device) + target_policy.eval() + + self._policies[i] = policy + self._target_policies[i] = target_policy + + # Critic + self._q_critic_nets: Dict[int, MultiQNet] = {} + self._target_q_critic_nets: Dict[int, MultiQNet] = {} + indexes = [0] if self._shared_critic else self._indexes + for i in indexes: + q_critic_net = self._get_q_critic_net_func() + q_critic_net.to(self._device) + target_q_critic_net = clone(q_critic_net) + target_q_critic_net.to(self._device) + target_q_critic_net.eval() + self._q_critic_nets[i] = q_critic_net + self._target_q_critic_nets[i] = target_q_critic_net + + def get_target_action_dict(self) -> Dict[int, torch.Tensor]: + agent_state_dict = { + i: ndarray_to_tensor(self._batch.agent_states[i], self._device) + for i in self._indexes + } # o + with torch.no_grad(): + action_dict = { + i: policy.get_actions_tensor(agent_state_dict[i]) + for i, policy in self._target_policies.items() + } + return action_dict + + def get_latest_action_dict(self) -> Tuple[dict, dict]: + agent_state_dict = { + i: ndarray_to_tensor(self._batch.agent_states[i], self._device) + for i in self._indexes + } # o + + latest_actions = {} + latest_action_logps = {} + for i, policy in self._policies.items(): + policy.train() + action = policy.get_actions_tensor(agent_state_dict[i]) + logps = policy.get_state_action_logps(agent_state_dict[i], action) + latest_actions[i] = action + latest_action_logps[i] = logps + + return latest_actions, latest_action_logps + + def get_ops_state_dict(self, scope: str = "all") -> dict: + ret_dict = {} + + if scope in ("all", "actor"): + ret_dict["policy_state"] = {i: self._policies[i].get_policy_state() for i in self._indexes} + ret_dict["target_policy_state"] = {i: self._target_policies[i].get_policy_state() for i in self._indexes} + if scope in ("all", "critic"): + indexes = [0] if self._shared_critic else self._indexes + ret_dict["critic_state"] = {i: self._q_critic_nets[i].get_net_state() for i in indexes} + ret_dict["target_critic_state"] = {i: 
self._target_q_critic_nets[i].get_net_state() for i in indexes}
+
+        return ret_dict
+
+    def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
+        if scope in ("all", "actor"):
+            for i in self._indexes:
+                self._policies[i].set_policy_state(ops_state_dict["policy_state"][i])
+                self._target_policies[i].set_policy_state(ops_state_dict["target_policy_state"][i])
+        if scope in ("all", "critic"):
+            indexes = [0] if self._shared_critic else self._indexes
+            for i in indexes:
+                self._q_critic_nets[i].set_net_state(ops_state_dict["critic_state"][i])
+                self._target_q_critic_nets[i].set_net_state(ops_state_dict["target_critic_state"][i])
+
+    def _get_critic_grad(
+        self,
+        batch: MultiTransitionBatch,
+        next_actions: List[torch.Tensor]
+    ) -> Dict[int, Dict[str, torch.Tensor]]:
+        states = ndarray_to_tensor(batch.states, self._device)  # x
+        actions = [ndarray_to_tensor(action, self._device) for action in batch.actions]  # a
+
+        next_states = ndarray_to_tensor(batch.next_states, self._device)  # x'
+        rewards = ndarray_to_tensor(np.vstack([reward for reward in batch.rewards]), self._device)  # r
+        terminals = ndarray_to_tensor(batch.terminals, self._device)  # d
+
+        for net in self._q_critic_nets.values():
+            net.train()
+
+        critic_loss_dict = {}
+        indexes = [0] if self._shared_critic else self._indexes
+        for i in indexes:
+            q_net = self._q_critic_nets[i]
+            target_q_net = self._target_q_critic_nets[i]
+            with torch.no_grad():
+                next_q_values = target_q_net.q_values(
+                    states=next_states,  # x'
+                    actions=next_actions  # a'
+                )  # Q'(x', a')
+            target_q_values = (rewards[i] + self._reward_discount * (1 - terminals.float()) * next_q_values)
+            q_values = q_net.q_values(
+                states=states,  # x
+                actions=actions  # a
+            )  # Q(x, a)
+            critic_loss = self._q_value_loss_func(q_values, target_q_values.detach()) * self._critic_loss_coef
+            critic_loss_dict[i] = critic_loss
+
+        return {
+            i: self._q_critic_nets[i].get_gradients(critic_loss_dict[i])
+            for i in indexes
+        }
+
+    def _get_actor_grad(
+        self,
+        batch: MultiTransitionBatch,
+        latest_actions: List[torch.Tensor],
+        latest_action_logps: List[torch.Tensor]
+    ) -> Dict[int, Dict[str, torch.Tensor]]:
+        states = ndarray_to_tensor(batch.states, self._device)  # x
+        actions = [ndarray_to_tensor(action, self._device) for action in batch.actions]  # a
+
+        for policy in self._policies.values():
+            policy.train()
+
+        actor_loss_dict = {}
+        for i in self._indexes:
+            q_net = self._q_critic_nets[i]
+            q_net.freeze()
+
+            action_backup = actions[i]
+            actions[i] = latest_actions[i]  # Replace with the latest action
+            actor_loss = -(q_net.q_values(
+                states=states,  # x
+                actions=actions  # [a^j_1, ..., a_i, ..., a^j_N]
+            ) * latest_action_logps[i]).mean()  # Q(x, a^j_1, ..., a_i, ..., a^j_N)
+            actor_loss_dict[i] = actor_loss
+
+            actions[i] = action_backup  # Restore the original action
+            q_net.unfreeze()
+
+        return {
+            i: self._policies[i].get_gradients(actor_loss_dict[i])
+            for i in self._indexes
+        }
+
+    def get_batch_grad(
+        self,
+        batch: MultiTransitionBatch,
+        tensor_dict: Dict[str, object] = None,
+        scope: str = "all"
+    ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]:
+        assert scope in ("all", "actor", "critic"), \
+            f"Unrecognized scope {scope}. Expecting 'all', 'actor', or 'critic'."
+ + if tensor_dict is None: + tensor_dict = {} + + grad_dict = {} + if scope in ("all", "critic"): + assert "next_actions" in tensor_dict + next_actions = tensor_dict["next_actions"] + assert isinstance(next_actions, list) + assert all(isinstance(action, torch.Tensor) for action in next_actions) + + grad_dict["critic_grads"] = self._get_critic_grad(batch, next_actions) + if scope in ("all", "actor"): + assert "latest_actions" in tensor_dict + assert "latest_action_logps" in tensor_dict + latest_actions = tensor_dict["latest_actions"] + latest_action_logps = tensor_dict["latest_action_logps"] + assert isinstance(latest_actions, list) and isinstance(latest_action_logps, list) + assert all(isinstance(action, torch.Tensor) for action in latest_actions) + assert all(isinstance(logps, torch.Tensor) for logps in latest_action_logps) + + grad_dict["actor_grads"] = self._get_actor_grad(batch, latest_actions, latest_action_logps) + + return grad_dict + + def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_ops: int) -> List[Dict[str, object]]: + raise NotImplementedError + + def _dispatch_batch(self, batch: MultiTransitionBatch, num_ops: int) -> List[MultiTransitionBatch]: + batch_size = batch.states.shape[0] + assert batch_size >= num_ops, \ + f"Batch size should be greater than or equal to num_ops, but got {batch_size} and {num_ops}." + sub_batch_indexes = [range(batch_size)[i::num_ops] for i in range(num_ops)] + sub_batches = [MultiTransitionBatch( + policy_names=[], + states=batch.states[indexes], + actions=[action[indexes] for action in batch.actions], + rewards=[reward[indexes] for reward in batch.rewards], + terminals=batch.terminals[indexes], + next_states=batch.next_states[indexes], + agent_states=[state[indexes] for state in batch.agent_states], + next_agent_states=[state[indexes] for state in batch.next_agent_states] + ) for indexes in sub_batch_indexes] + return sub_batches + + def update_critics(self, next_actions: List[torch.Tensor]) -> None: + grads = self._get_batch_grad( + self._batch, + tensor_dict={"next_actions": next_actions}, + scope="critic" + ) + + for i, grad in grads["critic_grads"].items(): + self._q_critic_nets[i].train() + self._q_critic_nets[i].apply_gradients(grad) + + def update_actors(self, latest_actions: List[torch.Tensor], latest_action_logps: List[torch.Tensor]) -> None: + grads = self._get_batch_grad( + self._batch, + tensor_dict={ + "latest_actions": latest_actions, + "latest_action_logps": latest_action_logps + }, + scope="actor" + ) + + for i, grad in grads["actor_grads"].items(): + self._policies[i].train() + self._policies[i].apply_gradients(grad) + + def soft_update_target(self) -> None: + for i in self._indexes: + self._target_policies[i].soft_update(self._policies[i], self._soft_update_coef) + + indexes = [0] if self._shared_critic else self._indexes + for i in indexes: + self._target_q_critic_nets[i].soft_update(self._q_critic_nets[i], self._soft_update_coef) + + +class DistributedDiscreteMADDPG(MultiTrainer): + def __init__( + self, + name: str, + reward_discount: float, + get_q_critic_net_func: Callable[[], MultiQNet], + group_size: int = 1, + policies: List[RLPolicy] = None, + replay_memory_capacity: int = 10000, + num_epoch: int = 10, + update_target_every: int = 5, + soft_update_coef: float = 0.5, + train_batch_size: int = 32, + q_value_loss_cls: Callable = None, + device: str = None, + critic_loss_coef: float = 1.0, + shared_critic: bool = False, + enable_data_parallelism: bool = False + ) -> None: + super(DistributedDiscreteMADDPG, 
self).__init__( + name=name, + device=device, + enable_data_parallelism=enable_data_parallelism, + train_batch_size=train_batch_size + ) + + self._get_q_critic_net_func = get_q_critic_net_func + self._critic_ops: Optional[DiscreteMADDPGTrainOps] = None + self._group_size = group_size + self._replay_memory_capacity = replay_memory_capacity + self._target_policies: List[DiscretePolicyGradient] = [] + self._shared_critic = shared_critic + if policies is not None: + self.register_policies(policies) + + self._num_epoch = num_epoch + self._update_target_every = update_target_every + self._policy_version = self._target_policy_version = 0 + self._soft_update_coef = soft_update_coef + self._reward_discount = reward_discount + self._critic_loss_coef = critic_loss_coef + + self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() + + def _register_policies_impl(self, policies: List[RLPolicy]) -> None: + if self._shared_critic: + self._critic_ops = DiscreteMADDPGTrainOps( + name="critic_ops", + reward_discount=self._reward_discount, get_q_critic_net_func=self._get_q_critic_net_func, + shared_critic=self._shared_critic, device=self._device, + enable_data_parallelism=self._enable_data_parallelism, + critic_loss_coef=self._critic_loss_coef, soft_update_coef=self._soft_update_coef, + update_target_every=self._update_target_every, q_value_loss_func=self._q_value_loss_func + ) + self._critic_ops.register_policies({}) # Register with empty policy dict to init the critic net + + self._ops_list: List[DiscreteMADDPGTrainOps] = [] + self._ops_indexes: List[List[int]] = [] + cursor = 0 + while cursor < self.num_policies: + cursor_end = min(cursor + self._group_size, self.num_policies) + indexes = list(range(cursor, cursor_end)) + + ops = DiscreteMADDPGTrainOps( + name=f"actor_ops__{cursor}_{cursor_end - 1}", + reward_discount=self._reward_discount, get_q_critic_net_func=self._get_q_critic_net_func, + shared_critic=self._shared_critic, device=self._device, + enable_data_parallelism=self._enable_data_parallelism, + critic_loss_coef=self._critic_loss_coef, soft_update_coef=self._soft_update_coef, + update_target_every=self._update_target_every, q_value_loss_func=self._q_value_loss_func + ) + ops.register_policies({i: policies[i] for i in indexes}) + + cursor = cursor_end + self._ops_list.append(ops) + self._ops_indexes.append(indexes) + + # Replay + self._replay_memory = RandomMultiReplayMemory( + capacity=self._replay_memory_capacity, + state_dim=self._get_q_critic_net_func().state_dim, + action_dims=[policy.action_dim for policy in policies], + agent_states_dims=[policy.state_dim for policy in policies] + ) + + def train_step(self) -> None: + for _ in range(self._num_epoch): + self._improve(self._get_batch()) + + def _improve(self, batch: MultiTransitionBatch) -> None: + for ops in self._ops_list: + ops.set_batch(batch) + + # Collect next actions + next_action_dict: Dict[int, torch.Tensor] = {} + for ops in self._ops_list: + next_action_dict.update(ops.get_target_action_dict()) + next_actions = [next_action_dict[i] for i in range(self.num_policies)] + + # Update critic + if self._shared_critic: + self._critic_ops.set_batch(batch) + self._critic_ops.update_critics(next_actions=next_actions) + critic_state_dict = self._critic_ops.get_ops_state_dict(scope="critic") + + # Sync latest critic to ops + for ops in self._ops_list: + ops.set_ops_state_dict(critic_state_dict, scope="critic") + else: + for ops in self._ops_list: + ops.update_critics(next_actions=next_actions) + + # 
Update actor + latest_actions_dict = {} + latest_action_logps_dict = {} + for ops in self._ops_list: + cur_action_dict, cur_logps_dict = ops.get_latest_action_dict() + latest_actions_dict.update(cur_action_dict) + latest_action_logps_dict.update(cur_logps_dict) + latest_actions = [latest_actions_dict[i] for i in range(self.num_policies)] + latest_action_logps = [latest_action_logps_dict[i] for i in range(self.num_policies)] + + for ops in self._ops_list: + ops.update_actors(latest_actions, latest_action_logps) + + # Update version + self._try_soft_update_target() + + def _try_soft_update_target(self) -> None: + self._policy_version += 1 + if self._policy_version - self._target_policy_version == self._update_target_every: + if self._shared_critic: + self._critic_ops.soft_update_target() + + for ops in self._ops_list: + ops.soft_update_target() + + self._target_policy_version = self._policy_version + + def get_policy_state_dict(self) -> Dict[str, object]: + policy_state_dict = {} + for ops in self._ops_list: + policy_state_dict.update(ops.get_policy_state_dict()) + return {name: policy_state_dict[i] for i, name in enumerate(self._policy_names)} + + def set_policy_state_dict(self, policy_state_dict: Dict[str, object]) -> None: + assert len(policy_state_dict) == self.num_policies + + for ops, indexes in zip(self._ops_list, self._ops_indexes): + cur_dict = {i: policy_state_dict[self._policy_names[i]] for i in indexes} + ops.set_policy_state_dict(cur_dict) diff --git a/maro/rl_v3/policy_trainer/dqn.py b/maro/rl_v3/policy_trainer/dqn.py index 68e7026a5..cf433cc0a 100644 --- a/maro/rl_v3/policy_trainer/dqn.py +++ b/maro/rl_v3/policy_trainer/dqn.py @@ -1,67 +1,50 @@ -from typing import Optional +from typing import Dict, List import torch -from maro.rl_v3.policy import ValueBasedPolicy +from maro.rl_v3.policy import RLPolicy, ValueBasedPolicy from maro.rl_v3.replay_memory import RandomReplayMemory from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor from maro.utils import clone + +from .abs_train_ops import SingleTrainOps from .abs_trainer import SingleTrainer -class DQN(SingleTrainer): - """ - TODO: docs. 
-    """
+class DQNTrainOps(SingleTrainOps):
     def __init__(
         self,
         name: str,
-        policy: ValueBasedPolicy = None,
-        replay_memory_capacity: int = 100000,
-        train_batch_size: int = 128,
-        num_epochs: int = 1,
+        device: torch.device,
         reward_discount: float = 0.9,
-        update_target_every: int = 5,
         soft_update_coef: float = 0.1,
         double: bool = False,
-        random_overwrite: bool = False,
-        device: str = None
+        enable_data_parallelism: bool = False
     ) -> None:
-        super(DQN, self).__init__(name, device)
+        super(DQNTrainOps, self).__init__(name, device, enable_data_parallelism)

-        self._policy: Optional[ValueBasedPolicy] = None
-        self._target_policy: Optional[ValueBasedPolicy] = None
-        self._replay_memory_capacity = replay_memory_capacity
-        self._random_overwrite = random_overwrite
-        if policy is not None:
-            self.register_policy(policy)
-
-        self._train_batch_size = train_batch_size
-        self._num_epochs = num_epochs
         self._reward_discount = reward_discount
-
-        self._policy_ver = self._target_policy_ver = 0
-        self._update_target_every = update_target_every
         self._soft_update_coef = soft_update_coef
         self._double = double
-
         self._loss_func = torch.nn.MSELoss()

-    def _record_impl(self, policy_name: str, transition_batch: TransitionBatch) -> None:
-        self._replay_memory.put(transition_batch)
+    def _register_policy_impl(self, policy: RLPolicy) -> None:
+        assert isinstance(policy, ValueBasedPolicy)

-    def _get_batch(self, batch_size: int = None) -> TransitionBatch:
-        return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size)
+        self._policy = policy
+        self._target_policy: ValueBasedPolicy = clone(policy)
+        self._target_policy.set_name(f"target_{policy.name}")
+        self._target_policy.eval()
+        self._target_policy.to_device(self._device)

-    def train_step(self) -> None:
-        for _ in range(self._num_epochs):
-            self._improve(self._get_batch())
-        self._policy_ver += 1
-        if self._policy_ver - self._target_policy_ver == self._update_target_every:
-            self._target_policy.soft_update(self._policy, self._soft_update_coef)
-            self._target_policy_ver = self._policy_ver
+    def get_batch_grad(
+        self,
+        batch: TransitionBatch,
+        tensor_dict: Dict[str, object] = None,
+        scope: str = "all"
+    ) -> Dict[str, Dict[str, torch.Tensor]]:
+        assert scope == "all", f"Unrecognized scope {scope}. Expecting 'all'."
- def _improve(self, batch: TransitionBatch) -> None: self._policy.train() states = ndarray_to_tensor(batch.states, self._device) next_states = ndarray_to_tensor(batch.next_states, self._device) @@ -81,20 +64,121 @@ def _improve(self, batch: TransitionBatch) -> None: target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() q_values = self._policy.q_values_tensor(states, actions) - loss = self._loss_func(q_values, target_q_values) + loss: torch.Tensor = self._loss_func(q_values, target_q_values) - self._policy.step(loss) + return {"grad": self._policy.get_gradients(loss)} - def _register_policy_impl(self, policy: ValueBasedPolicy) -> None: - assert isinstance(policy, ValueBasedPolicy) + def _dispatch_batch(self, batch: TransitionBatch, num_ops: int) -> List[TransitionBatch]: + raise NotImplementedError - self._policy = policy - self._target_policy: ValueBasedPolicy = clone(policy) - self._target_policy.set_name(f"target_{policy.name}") - self._target_policy.eval() - self._target_policy.to_device(self._device) + def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_ops: int) -> List[Dict[str, object]]: + raise NotImplementedError + + def get_ops_state_dict(self, scope: str = "all") -> dict: + return { + "policy_state": self._policy.get_policy_state(), + "target_policy_state": self._target_policy.get_policy_state() + } + + def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None: + self._policy.set_policy_state(ops_state_dict["policy_state"]) + self._target_policy.set_policy_state(ops_state_dict["target_policy_state"]) + + def update(self) -> None: + grad_dict = self._get_batch_grad(self._batch) + + self._policy.train() + self._policy.apply_gradients(grad_dict["grad"]) + + def soft_update_target(self) -> None: + self._target_policy.soft_update(self._policy, self._soft_update_coef) + + +class DQN(SingleTrainer): + """The Deep-Q-Networks algorithm. + + See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details. + + Args: + name (str): Unique identifier for the policy. + policy (ValueBasedPolicy): The policy to be trained. + replay_memory_capacity (int): Capacity of the replay memory. Defaults to 100000. + train_batch_size (int): Batch size for training the Q-net. Defaults to 128. + num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1. + reward_discount (float): Reward decay as defined in standard RL terminology. Defaults to 0.9. + update_target_every (int): Number of gradient steps between target model updates. Defaults to 5. + soft_update_coef (float): Soft update coefficient, e.g., + target_model = (soft_update_coef) * eval_model + (1-soft_update_coef) * target_model. + Defaults to 0.1. + double (bool): If True, the next Q values will be computed according to the double DQN algorithm, + i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)). + See https://arxiv.org/pdf/1509.06461.pdf for details. Defaults to False. + random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True, + overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with + wrap-around. Defaults to False. + device (str): Identifier for the torch device. The policy will be moved to the specified device. If it is + None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None. 
+ enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False. + """ + def __init__( + self, + name: str, + policy: ValueBasedPolicy = None, + replay_memory_capacity: int = 100000, + train_batch_size: int = 128, + num_epochs: int = 1, + reward_discount: float = 0.9, + update_target_every: int = 5, + soft_update_coef: float = 0.1, + double: bool = False, + random_overwrite: bool = False, + device: str = None, + enable_data_parallelism: bool = False + ) -> None: + super(DQN, self).__init__( + name=name, + device=device, + enable_data_parallelism=enable_data_parallelism, + train_batch_size=train_batch_size + ) + + self._replay_memory_capacity = replay_memory_capacity + self._random_overwrite = random_overwrite + if policy is not None: + self.register_policy(policy) + + self._train_batch_size = train_batch_size + self._num_epochs = num_epochs + self._reward_discount = reward_discount + + self._update_target_every = update_target_every + self._soft_update_coef = soft_update_coef + self._double = double + + self._loss_func = torch.nn.MSELoss() + self._policy_version = self._target_policy_version = 0 + + def train_step(self) -> None: + for _ in range(self._num_epochs): + self._ops.set_batch(self._get_batch()) + self._ops.update() + self._try_soft_update_target() + + def _register_policy_impl(self, policy: ValueBasedPolicy) -> None: + self._ops = DQNTrainOps( + name="ops", device=self._device, reward_discount=self._reward_discount, + soft_update_coef=self._soft_update_coef, double=self._double, + enable_data_parallelism=self._enable_data_parallelism + ) + self._ops.register_policy(policy) self._replay_memory = RandomReplayMemory( capacity=self._replay_memory_capacity, state_dim=policy.state_dim, action_dim=policy.action_dim, random_overwrite=self._random_overwrite ) + + def _try_soft_update_target(self) -> None: + self._policy_version += 1 + if self._policy_version - self._target_policy_version == self._update_target_every: + self._ops.soft_update_target() + self._target_policy_version = self._policy_version diff --git a/maro/rl_v3/policy_trainer/maac.py b/maro/rl_v3/policy_trainer/maac.py deleted file mode 100644 index ae4827178..000000000 --- a/maro/rl_v3/policy_trainer/maac.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import Callable, List, Optional - -import torch - -from maro.rl_v3.model import MultiQNet -from maro.rl_v3.policy import DiscretePolicyGradient, RLPolicy -from maro.rl_v3.replay_memory import RandomMultiReplayMemory -from maro.rl_v3.utils import MultiTransitionBatch, ndarray_to_tensor -from maro.utils import clone -from .abs_trainer import MultiTrainer - - -class DiscreteMultiActorCritic(MultiTrainer): - def __init__( - self, - name: str, - reward_discount: float, - get_v_critic_net_func: Callable[[], MultiQNet], - policies: List[RLPolicy] = None, - replay_memory_capacity: int = 10000, - num_epoch: int = 10, - update_target_every: int = 5, - soft_update_coef: float = 1.0, - train_batch_size: int = 32, - q_value_loss_cls: Callable = None, - device: str = None, - critic_loss_coef: float = 0.1 - - ) -> None: - super(DiscreteMultiActorCritic, self).__init__(name, device) - - self._get_v_critic_net_func = get_v_critic_net_func - self._q_critic_net: Optional[MultiQNet] = None - self._target_q_critic_net: Optional[MultiQNet] = None - self._replay_memory_capacity = replay_memory_capacity - self._target_policies: Optional[List[DiscretePolicyGradient]] = None - if policies is not None: - self.register_policies(policies) - - self._num_epoch = 
num_epoch - self._update_target_every = update_target_every - self._policy_ver = self._target_policy_ver = 0 - self._soft_update_coef = soft_update_coef - self._train_batch_size = train_batch_size - self._reward_discount = reward_discount - self._critic_loss_coef = critic_loss_coef - - self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() - - def _record_impl(self, transition_batch: MultiTransitionBatch) -> None: - self._replay_memory.put(transition_batch) - - def _register_policies_impl(self, policies: List[RLPolicy]) -> None: - assert all(isinstance(policy, DiscretePolicyGradient) for policy in policies) - - self._policies = policies - self._policy_dict = { - policy.name: policy for policy in policies - } - self._q_critic_net = self._get_v_critic_net_func() - - self._replay_memory = RandomMultiReplayMemory( - capacity=self._replay_memory_capacity, - state_dim=self._q_critic_net.state_dim, - action_dims=[policy.action_dim for policy in policies], - agent_states_dims=[policy.state_dim for policy in policies] - ) - - self._target_policies: List[DiscretePolicyGradient] = [] - for policy in self._policies: - target_policy = clone(policy) - target_policy.set_name(f"target_{policy.name}") - self._target_policies.append(target_policy) - - for policy in self._target_policies: - policy.eval() - self._target_q_critic_net: MultiQNet = clone(self._q_critic_net) - self._target_q_critic_net.eval() - - for policy in self._target_policies: - policy.to_device(self._device) - self._q_critic_net.to(self._device) - self._target_q_critic_net.to(self._device) - - def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._train_batch_size) - - def train_step(self) -> None: - for _ in range(self._num_epoch): - self._improve(self._get_batch()) - self._update_target_policy() - - def _improve(self, batch: MultiTransitionBatch) -> None: - """ - References: https://arxiv.org/pdf/1706.02275.pdf - """ - for policy in self._policies: - policy.train() - - states = ndarray_to_tensor(batch.states, self._device) # x - next_states = ndarray_to_tensor(batch.next_states, self._device) # x' - agent_states = [ndarray_to_tensor(agent_state, self._device) for agent_state in batch.agent_states] # o - actions = [ndarray_to_tensor(action, self._device) for action in batch.actions] # a - rewards = [ndarray_to_tensor(reward, self._device) for reward in batch.rewards] # r - terminals = ndarray_to_tensor(batch.terminals, self._device) # d - - with torch.no_grad(): - next_actions = [ - policy.get_actions_tensor(agent_state) # a' = miu'(o) - for policy, agent_state in zip(self._target_policies, agent_states) - ] - next_q_values = self._target_q_critic_net.q_values( - states=next_states, # x' - actions=next_actions # a' - ) # Q'(x', a') - - latest_actions = [] - latest_action_logps = [] - for policy, agent_state in zip(self._policies, agent_states): - assert isinstance(policy, DiscretePolicyGradient) - latest_actions.append(policy.get_actions_tensor(agent_state)) # a = miu(o) - latest_action_logps.append(policy.get_state_action_logps( - agent_state, # o - latest_actions[-1] # a - )) # log pi(a|o) - - for i in range(len(self._policies)): - # Update critic - # y = r + gamma * (1 - d) * Q' - target_q_values = (rewards[i] + self._reward_discount * (1 - terminals.float()) * next_q_values).detach() - q_values = self._q_critic_net.q_values( - states=states, # x - actions=actions # a - ) # Q(x, a) - 
critic_loss = self._q_value_loss_func(q_values, target_q_values) # MSE(Q(x, a), Q'(x', a')) - self._q_critic_net.step(critic_loss * self._critic_loss_coef) - - # Update actor - self._q_critic_net.freeze() - - action_backup = actions[i] - actions[i] = latest_actions[i] # Replace latest action - policy_loss = -(self._q_critic_net.q_values( - states=states, # x - actions=actions # [a^j_1, ..., a_i, ..., a^j_N] - ) * latest_action_logps[i]).mean() # Q(x, a^j_1, ..., a_i, ..., a^j_N) - self._policies[i].step(policy_loss) - - actions[i] = action_backup # Restore original action - self._q_critic_net.unfreeze() - - def _update_target_policy(self) -> None: - self._policy_ver += 1 - if self._policy_ver - self._target_policy_ver == self._update_target_every: - for policy, target_policy in zip(self._policies, self._target_policies): - target_policy.soft_update(policy, self._soft_update_coef) - self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) - self._target_policy_ver = self._policy_ver diff --git a/maro/rl_v3/replay_memory/__init__.py b/maro/rl_v3/replay_memory/__init__.py index b8bacb5f3..d8029cced 100644 --- a/maro/rl_v3/replay_memory/__init__.py +++ b/maro/rl_v3/replay_memory/__init__.py @@ -1,6 +1,6 @@ from .replay_memory import ( - FIFOMultiReplayMemory, FIFOReplayMemory, MultiReplayMemory, MultiTransitionBatch, - RandomMultiReplayMemory, RandomReplayMemory, ReplayMemory, TransitionBatch + FIFOMultiReplayMemory, FIFOReplayMemory, MultiReplayMemory, MultiTransitionBatch, RandomMultiReplayMemory, + RandomReplayMemory, ReplayMemory, TransitionBatch ) __all__ = [ diff --git a/maro/rl_v3/replay_memory/replay_memory.py b/maro/rl_v3/replay_memory/replay_memory.py index 12f25c887..1c3104419 100644 --- a/maro/rl_v3/replay_memory/replay_memory.py +++ b/maro/rl_v3/replay_memory/replay_memory.py @@ -3,7 +3,7 @@ import numpy as np -from maro.rl_v3.utils import MultiTransitionBatch, SHAPE_CHECK_FLAG, TransitionBatch, match_shape +from maro.rl_v3.utils import SHAPE_CHECK_FLAG, MultiTransitionBatch, TransitionBatch, match_shape class AbsIndexScheduler(object, metaclass=ABCMeta): diff --git a/maro/rl_v3/tmp_example_multi/env_sampler.py b/maro/rl_v3/tmp_example_multi/env_sampler.py index 72a37aa52..d99bb16fa 100644 --- a/maro/rl_v3/tmp_example_multi/env_sampler.py +++ b/maro/rl_v3/tmp_example_multi/env_sampler.py @@ -5,6 +5,7 @@ from maro.rl_v3.learning import AbsEnvSampler, CacheElement, SimpleAgentWrapper from maro.simulator import Env from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent + from .config import ( action_shaping_conf, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes ) diff --git a/maro/rl_v3/tmp_example_multi/main.py b/maro/rl_v3/tmp_example_multi/main.py index 61aad9c59..aedabd650 100644 --- a/maro/rl_v3/tmp_example_multi/main.py +++ b/maro/rl_v3/tmp_example_multi/main.py @@ -1,6 +1,7 @@ from maro.rl_v3 import run_workflow_centralized_mode from maro.rl_v3.learning import SimpleAgentWrapper, SimpleTrainerManager from maro.simulator import Env + from .callbacks import cim_post_collect, cim_post_evaluate from .config import algorithm, env_conf, running_mode from .env_sampler import CIMEnvSampler diff --git a/maro/rl_v3/tmp_example_multi/nets.py b/maro/rl_v3/tmp_example_multi/nets.py index 971f70f35..51ccc837c 100644 --- a/maro/rl_v3/tmp_example_multi/nets.py +++ b/maro/rl_v3/tmp_example_multi/nets.py @@ -48,16 +48,16 @@ def freeze(self) -> None: def unfreeze(self) -> None: self.unfreeze_all_parameters() 
- def step(self, loss: torch.Tensor) -> None: - self._actor_optim.zero_grad() - loss.backward() - self._actor_optim.step() - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: self._actor_optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + self._actor_optim.step() + def get_net_state(self) -> dict: return { "network": self.state_dict(), @@ -81,16 +81,16 @@ def __init__(self) -> None: def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor: return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1) - def step(self, loss: torch.Tensor) -> None: - self._critic_optim.zero_grad() - loss.backward() - self._critic_optim.step() - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: self._critic_optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + self._critic_optim.step() + def get_net_state(self) -> dict: return { "network": self.state_dict(), diff --git a/maro/rl_v3/tmp_example_multi/policies.py b/maro/rl_v3/tmp_example_multi/policies.py index 7eb768bcd..0ca7ec7c8 100644 --- a/maro/rl_v3/tmp_example_multi/policies.py +++ b/maro/rl_v3/tmp_example_multi/policies.py @@ -1,6 +1,7 @@ from maro.rl_v3.policy import DiscretePolicyGradient -from maro.rl_v3.policy_trainer import DiscreteMultiActorCritic, DiscreteMADDPG +from maro.rl_v3.policy_trainer import DistributedDiscreteMADDPG from maro.rl_v3.workflow import preprocess_get_policy_func_dict + from .config import algorithm, running_mode from .nets import MyActorNet, MyMultiCriticNet @@ -10,23 +11,13 @@ } # ##################################################################################################################### -if algorithm == "maac": - get_policy_func_dict = { - f"{algorithm}.{i}": lambda name: DiscretePolicyGradient( - name=name, policy_net=MyActorNet()) for i in range(4) - } - get_trainer_func_dict = { - f"{algorithm}.{i}_trainer": lambda name: DiscreteMultiActorCritic( - name=name, get_v_critic_net_func=lambda: MyMultiCriticNet(), device="cpu", **ac_conf - ) for i in range(4) - } -elif algorithm == "discrete_maddpg": +if algorithm == "discrete_maddpg": get_policy_func_dict = { f"{algorithm}.{i}": lambda name: DiscretePolicyGradient( name=name, policy_net=MyActorNet()) for i in range(4) } get_trainer_func_dict = { - f"{algorithm}.{i}_trainer": lambda name: DiscreteMADDPG( + f"{algorithm}.{i}_trainer": lambda name: DistributedDiscreteMADDPG( name=name, get_q_critic_net_func=lambda: MyMultiCriticNet(), device="cpu", **ac_conf ) for i in range(4) } diff --git a/maro/rl_v3/tmp_example_single/env_sampler.py b/maro/rl_v3/tmp_example_single/env_sampler.py index c6a7ae423..8821835b8 100644 --- a/maro/rl_v3/tmp_example_single/env_sampler.py +++ b/maro/rl_v3/tmp_example_single/env_sampler.py @@ -5,6 +5,7 @@ from maro.rl_v3.learning import AbsEnvSampler, CacheElement, SimpleAgentWrapper from maro.simulator import Env from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent + from .config import ( action_shaping_conf, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes ) diff --git a/maro/rl_v3/tmp_example_single/main.py b/maro/rl_v3/tmp_example_single/main.py index 
61aad9c59..aedabd650 100644 --- a/maro/rl_v3/tmp_example_single/main.py +++ b/maro/rl_v3/tmp_example_single/main.py @@ -1,6 +1,7 @@ from maro.rl_v3 import run_workflow_centralized_mode from maro.rl_v3.learning import SimpleAgentWrapper, SimpleTrainerManager from maro.simulator import Env + from .callbacks import cim_post_collect, cim_post_evaluate from .config import algorithm, env_conf, running_mode from .env_sampler import CIMEnvSampler diff --git a/maro/rl_v3/tmp_example_single/nets.py b/maro/rl_v3/tmp_example_single/nets.py index 85188e633..f4a00f4c8 100644 --- a/maro/rl_v3/tmp_example_single/nets.py +++ b/maro/rl_v3/tmp_example_single/nets.py @@ -4,6 +4,7 @@ from torch.optim import Adam, RMSprop from maro.rl_v3.model import DiscretePolicyNet, DiscreteQNet, FullyConnected, VNet + from .config import action_shaping_conf, state_dim q_net_conf = { @@ -50,16 +51,16 @@ def __init__(self) -> None: def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: return self._fc(states) - def step(self, loss: torch.Tensor) -> None: - self._optim.zero_grad() - loss.backward() - self._optim.step() - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: self._optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + self._optim.step() + def get_net_state(self) -> object: return {"network": self.state_dict(), "optim": self._optim.state_dict()} @@ -90,16 +91,16 @@ def freeze(self) -> None: def unfreeze(self) -> None: self.unfreeze_all_parameters() - def step(self, loss: torch.Tensor) -> None: - self._actor_optim.zero_grad() - loss.backward() - self._actor_optim.step() - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: self._actor_optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + self._actor_optim.step() + def get_net_state(self) -> dict: return { "network": self.state_dict(), @@ -120,16 +121,16 @@ def __init__(self) -> None: def _get_v_values(self, states: torch.Tensor) -> torch.Tensor: return self._critic(states).squeeze(-1) - def step(self, loss: torch.Tensor) -> None: - self._critic_optim.zero_grad() - loss.backward() - self._critic_optim.step() - def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: self._critic_optim.zero_grad() loss.backward() return {name: param.grad for name, param in self.named_parameters()} + def apply_gradients(self, grad: dict) -> None: + for name, param in self.named_parameters(): + param.grad = grad[name] + self._critic_optim.step() + def get_net_state(self) -> dict: return { "network": self.state_dict(), diff --git a/maro/rl_v3/tmp_example_single/policies.py b/maro/rl_v3/tmp_example_single/policies.py index b2344c074..a40189da1 100644 --- a/maro/rl_v3/tmp_example_single/policies.py +++ b/maro/rl_v3/tmp_example_single/policies.py @@ -4,6 +4,7 @@ from maro.rl_v3.policy import DiscretePolicyGradient, ValueBasedPolicy from maro.rl_v3.policy_trainer import DQN, DiscreteActorCritic from maro.rl_v3.workflow import preprocess_get_policy_func_dict + from .config import algorithm, running_mode from .nets import MyActorNet, MyCriticNet, MyQNet diff --git a/maro/rl_v3/workflow.py b/maro/rl_v3/workflow.py index e53222dc3..626c84298 100644 --- a/maro/rl_v3/workflow.py +++ 
b/maro/rl_v3/workflow.py @@ -10,7 +10,7 @@ def preprocess_get_policy_func_dict( running_mode: str ) -> Dict[str, Callable[[str], RLPolicy]]: if running_mode == "centralized": - print(f"Pre-create the policies under centralized mode.") + print("Pre-create the policies under centralized mode.") policy_dict = {name: get_policy_func(name) for name, get_policy_func in get_policy_func_dict.items()} return {name: lambda name: policy_dict[name] for name in policy_dict} elif running_mode == "decentralized": diff --git a/maro/rl_v3/workflows/grad_worker.py b/maro/rl_v3/workflows/grad_worker.py new file mode 100644 index 000000000..03d6841af --- /dev/null +++ b/maro/rl_v3/workflows/grad_worker.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import time + +from maro.communication import Proxy, SessionMessage +from maro.rl.utils import MsgKey, MsgTag +from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module + +if __name__ == "__main__": + # TODO: WORKERID in docker compose script. + trainer_worker_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "trainer_worker_func_dict") + worker_id = f"GRAD_WORKER.{from_env('WORKERID')}" + num_trainer_workers = from_env("NUMTRAINERWORKERS") if from_env("TRAINERTYPE") == "distributed" else 0 + max_cached_policies = from_env("MAXCACHED", required=False, default=10) + + group = from_env("POLICYGROUP", required=False, default="learn") + policy_dict = {} + active_policies = [] + if num_trainer_workers == 0: + # no remote nodes for trainer workers + num_trainer_workers = len(trainer_worker_func_dict) + + peers = {"trainer": 1, "trainer_workers": num_trainer_workers, "task_queue": 1} + proxy = Proxy( + group, "grad_worker", peers, component_name=worker_id, + redis_address=(from_env("REDISHOST"), from_env("REDISPORT")), + max_peer_discovery_retries=50 + ) + logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id) + + for msg in proxy.receive(): + if msg.tag == MsgTag.EXIT: + logger.info("Exiting...") + proxy.close() + break + elif msg.tag == MsgTag.COMPUTE_GRAD: + t0 = time.time() + msg_body = {MsgKey.LOSS_INFO: dict(), MsgKey.POLICY_IDS: list()} + for name, batch in msg.body[MsgKey.GRAD_TASK].items(): + if name not in policy_dict: + if len(policy_dict) > max_cached_policies: + # remove the oldest one when size exceeds. + policy_to_remove = active_policies.pop() + policy_dict.pop(policy_to_remove) + # Initialize + policy_dict[name] = trainer_worker_func_dict[name](name) + active_policies.insert(0, name) + logger.info(f"Initialized policies {name}") + + tensor_dict = msg.body[MsgKey.TENSOR][name] + policy_dict[name].set_ops_state_dict(msg.body[MsgKey.POLICY_STATE][name]) + grad_dict = policy_dict[name].get_batch_grad( + batch, tensor_dict, scope=msg.body[MsgKey.GRAD_SCOPE][name]) + msg_body[MsgKey.LOSS_INFO][name] = grad_dict + msg_body[MsgKey.POLICY_IDS].append(name) + # put the latest one to queue head + active_policies.remove(name) + active_policies.insert(0, name) + + logger.debug(f"total policy update time: {time.time() - t0}") + proxy.reply(msg, tag=MsgTag.COMPUTE_GRAD_DONE, body=msg_body) + # release worker at task queue + proxy.isend(SessionMessage( + MsgTag.RELEASE_WORKER, proxy.name, "TASK_QUEUE", body={MsgKey.WORKER_ID: worker_id} + )) + else: + logger.info(f"Wrong message tag: {msg.tag}") + raise TypeError
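The example nets in this patch replace the old step(loss) method with a get_gradients(loss) / apply_gradients(grad) pair so that gradient computation can be shipped to remote grad workers, and _dispatch_batch splits a batch across workers by index striding. Below is a minimal, self-contained sketch of these two conventions; TinyNet, dispatch_batch, and the sequential-apply loop are illustrative stand-ins, not code from this patch.

# A minimal sketch of the gradient-exchange convention used by the updated example nets
# and of the index-striding batch dispatch. TinyNet and dispatch_batch are illustrative
# names only; they are not part of this patch.
from typing import Dict, List

import torch
from torch.optim import SGD


class TinyNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._fc = torch.nn.Linear(4, 1)
        self._optim = SGD(self.parameters(), lr=0.1)

    def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Same pattern as the nets in this patch: zero, backward, then collect grads by name.
        self._optim.zero_grad()
        loss.backward()
        return {name: param.grad for name, param in self.named_parameters()}

    def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None:
        # Assign externally computed gradients, then take an optimizer step.
        for name, param in self.named_parameters():
            param.grad = grad[name]
        self._optim.step()


def dispatch_batch(data: torch.Tensor, num_ops: int) -> List[torch.Tensor]:
    # Strided split, mirroring the `range(batch_size)[i::num_ops]` logic in _dispatch_batch.
    return [data[i::num_ops] for i in range(num_ops)]


if __name__ == "__main__":
    net = TinyNet()
    states, targets = torch.randn(8, 4), torch.randn(8, 1)
    # Each sub-batch would normally go to a separate grad worker; here the returned
    # gradient dicts are simply applied one after another.
    for sub_states, sub_targets in zip(dispatch_batch(states, 2), dispatch_batch(targets, 2)):
        loss = torch.nn.functional.mse_loss(net._fc(sub_states), sub_targets)
        net.apply_gradients(net.get_gradients(loss))

The soft-update rule quoted in the DQN and MADDPG docstrings, target_model = soft_update_coef * eval_model + (1 - soft_update_coef) * target_model, is the standard Polyak averaging step performed by the soft_update_target methods above.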