From c036a4543ebbeeb5b59d94c05f7ff1e7bb3b2969 Mon Sep 17 00:00:00 2001
From: "GQ.Chen"
Date: Tue, 14 Dec 2021 16:57:36 +0800
Subject: [PATCH] Rl v3 data parallel grad worker (#432)

* Fit the new `trainer_worker` into `grad_worker` and `task_queue`.
* Add batch dispatching.
* Add `tensor_dict` to the task submit interface.
* Move `_remote_learn` to `AbsTrainWorker`.
---
 maro/rl/data_parallelism/task_queue.py    | 11 ++-
 maro/rl/utils/message_enums.py            |  2 +
 .../distributed_discrete_maddpg.py        | 20 ++++++
 maro/rl_v3/policy_trainer/train_worker.py | 30 +++++++-
 maro/rl_v3/workflows/grad_worker.py       | 70 +++++++++++++++++++
 5 files changed, 129 insertions(+), 4 deletions(-)
 create mode 100644 maro/rl_v3/workflows/grad_worker.py

diff --git a/maro/rl/data_parallelism/task_queue.py b/maro/rl/data_parallelism/task_queue.py
index 7f88e916d..28e6a4393 100644
--- a/maro/rl/data_parallelism/task_queue.py
+++ b/maro/rl/data_parallelism/task_queue.py
@@ -5,6 +5,8 @@
 from multiprocessing import Manager, Process, Queue, managers
 from typing import Dict, List
 
+import torch
+
 from maro.communication import Proxy, SessionMessage
 from maro.rl.utils import MsgKey, MsgTag
 from maro.utils import DummyLogger, Logger
@@ -34,13 +36,18 @@ def request_workers(self, task_queue_server_name="TASK_QUEUE"):
         return worker_list
 
     # TODO: rename this method
-    def submit(self, worker_id_list: List, batch_list: List, policy_state: Dict, policy_name: str):
+    def submit(
+        self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str,
+        scope: str = None
+    ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
         """Learn a batch of data on several grad workers."""
         msg_dict = defaultdict(lambda: defaultdict(dict))
         loss_info_by_policy = {policy_name: []}
-        for worker_id, batch in zip(worker_id_list, batch_list):
+        for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
             msg_dict[worker_id][MsgKey.GRAD_TASK][policy_name] = batch
+            msg_dict[worker_id][MsgKey.TENSOR][policy_name] = tensor_dict
             msg_dict[worker_id][MsgKey.POLICY_STATE][policy_name] = policy_state
+            msg_dict[worker_id][MsgKey.GRAD_SCOPE][policy_name] = scope
             # data-parallel by multiple remote gradient workers
             self._proxy.isend(SessionMessage(
                 MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id]))
diff --git a/maro/rl/utils/message_enums.py b/maro/rl/utils/message_enums.py
index a4c317059..d1d57e0af 100644
--- a/maro/rl/utils/message_enums.py
+++ b/maro/rl/utils/message_enums.py
@@ -38,8 +38,10 @@ class MsgKey(Enum):
     ROLLOUT_INFO = "rollout_info"
     TRACKER = "tracker"
     GRAD_TASK = "grad_task"
+    GRAD_SCOPE = "grad_scope"
     LOSS_INFO = "loss_info"
     STATE = "state"
+    TENSOR = "tensor"
     POLICY_STATE = "policy_state"
     EXPLORATION_STEP = "exploration_step"
     VERSION = "version"
diff --git a/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
index d3c4fe084..089ec5e8d 100644
--- a/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
+++ b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
@@ -229,6 +229,26 @@ def get_batch_grad(
 
         return grad_dict
 
+    def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_workers: int) -> List[Dict[str, object]]:
+        raise NotImplementedError
+
+    def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        batch_size = batch.states.shape[0]
+        assert batch_size >= num_workers, \
+            f"Batch size should be greater than or equal to num_workers, but got {batch_size} and {num_workers}."
+        sub_batch_indexes = [range(batch_size)[i::num_workers] for i in range(num_workers)]
+        sub_batches = [MultiTransitionBatch(
+            policy_names=[],
+            states=batch.states[indexes],
+            actions=[action[indexes] for action in batch.actions],
+            rewards=[reward[indexes] for reward in batch.rewards],
+            terminals=batch.terminals[indexes],
+            next_states=batch.next_states[indexes],
+            agent_states=[state[indexes] for state in batch.agent_states],
+            next_agent_states=[state[indexes] for state in batch.next_agent_states]
+        ) for indexes in sub_batch_indexes]
+        return sub_batches
+
     def update_critics(self, next_actions: List[torch.Tensor]) -> None:
         grads = self._get_batch_grad(
             self._batch,
diff --git a/maro/rl_v3/policy_trainer/train_worker.py b/maro/rl_v3/policy_trainer/train_worker.py
index 63c17e66e..268c432ad 100644
--- a/maro/rl_v3/policy_trainer/train_worker.py
+++ b/maro/rl_v3/policy_trainer/train_worker.py
@@ -5,6 +5,7 @@
 
 from maro.communication import Proxy
 from maro.rl.data_parallelism import TaskQueueClient
+from maro.rl.utils import average_grads
 from maro.rl_v3.policy import RLPolicy
 from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch
 
@@ -33,11 +34,28 @@ def _get_batch_grad(
         scope: str = "all"
     ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]:
         if self._enable_data_parallelism:
-            assert self._task_queue_client is not None
-            raise NotImplementedError # TODO
+            gradients = self._remote_learn(batch, tensor_dict, scope)
+            return average_grads(gradients)
         else:
             return self.get_batch_grad(batch, tensor_dict, scope)
 
+    def _remote_learn(
+        self,
+        batch: MultiTransitionBatch,
+        tensor_dict: Dict[str, object] = None,
+        scope: str = "all"
+    ) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]:
+        assert self._task_queue_client is not None
+        worker_id_list = self._task_queue_client.request_workers()
+        batch_list = self._dispatch_batch(batch, len(worker_id_list))
+        # TODO: implement _dispatch_tensor_dict
+        tensor_dict_list = self._dispatch_tensor_dict(tensor_dict, len(worker_id_list))
+        worker_state = self.get_worker_state_dict()
+        worker_name = self.name
+        loss_info_by_name = self._task_queue_client.submit(
+            worker_id_list, batch_list, tensor_dict_list, worker_state, worker_name, scope)
+        return loss_info_by_name[worker_name]
+
     @abstractmethod
     def get_batch_grad(
         self,
@@ -47,6 +65,14 @@ def get_batch_grad(
     ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]:
         raise NotImplementedError
 
+    @abstractmethod
+    def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_workers: int) -> List[Dict[str, object]]:
+        raise NotImplementedError
+
     def init_data_parallel(self, *args, **kwargs) -> None:
         """
         Initialize a proxy in the policy, for data-parallel training.
diff --git a/maro/rl_v3/workflows/grad_worker.py b/maro/rl_v3/workflows/grad_worker.py
new file mode 100644
index 000000000..2bdef1f0e
--- /dev/null
+++ b/maro/rl_v3/workflows/grad_worker.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import time
+
+from maro.communication import Proxy, SessionMessage
+from maro.rl.utils import MsgKey, MsgTag
+from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module
+
+if __name__ == "__main__":
+    # TODO: WORKERID in docker compose script.
+    trainer_worker_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "trainer_worker_func_dict")
+    worker_id = f"GRAD_WORKER.{from_env('WORKERID')}"
+    num_trainer_workers = from_env("NUMTRAINERWORKERS") if from_env("TRAINERTYPE") == "distributed" else 0
+    max_cached_policies = from_env("MAXCACHED", required=False, default=10)
+
+    group = from_env("POLICYGROUP", required=False, default="learn")
+    policy_dict = {}
+    active_policies = []
+    if num_trainer_workers == 0:
+        # No remote nodes for trainer workers.
+        num_trainer_workers = len(trainer_worker_func_dict)
+
+    peers = {"trainer": 1, "trainer_workers": num_trainer_workers, "task_queue": 1}
+    proxy = Proxy(
+        group, "grad_worker", peers, component_name=worker_id,
+        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
+        max_peer_discovery_retries=50
+    )
+    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id)
+
+    for msg in proxy.receive():
+        if msg.tag == MsgTag.EXIT:
+            logger.info("Exiting...")
+            proxy.close()
+            break
+        elif msg.tag == MsgTag.COMPUTE_GRAD:
+            t0 = time.time()
+            msg_body = {MsgKey.LOSS_INFO: dict(), MsgKey.POLICY_IDS: list()}
+            for name, batch in msg.body[MsgKey.GRAD_TASK].items():
+                if name not in policy_dict:
+                    if len(policy_dict) > max_cached_policies:
+                        # Evict the least recently used policy when the cache size limit is exceeded.
+                        policy_to_remove = active_policies.pop()
+                        policy_dict.pop(policy_to_remove)
+                    # Initialize the train worker for this policy.
+                    policy_dict[name] = trainer_worker_func_dict[name](name)
+                    active_policies.insert(0, name)
+                    logger.info(f"Initialized policy {name}")
+
+                tensor_dict = msg.body[MsgKey.TENSOR][name]
+                policy_dict[name].set_worker_state_dict(msg.body[MsgKey.POLICY_STATE][name])
+                grad_dict = policy_dict[name].get_batch_grad(
+                    batch, tensor_dict, scope=msg.body[MsgKey.GRAD_SCOPE][name])
+                msg_body[MsgKey.LOSS_INFO][name] = grad_dict
+                msg_body[MsgKey.POLICY_IDS].append(name)
+                # Move the most recently used policy to the head of the list.
+                active_policies.remove(name)
+                active_policies.insert(0, name)
+
+            logger.debug(f"total policy update time: {time.time() - t0}")
+            proxy.reply(msg, tag=MsgTag.COMPUTE_GRAD_DONE, body=msg_body)
+            # Release this worker back to the task queue.
+            proxy.isend(SessionMessage(
+                MsgTag.RELEASE_WORKER, proxy.name, "TASK_QUEUE", body={MsgKey.WORKER_ID: worker_id}
+            ))
+        else:
+            logger.error(f"Unexpected message tag: {msg.tag}")
+            raise TypeError(f"Unexpected message tag: {msg.tag}")
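
Note (illustrative sketch, not part of the patch): `_dispatch_batch` splits a batch across grad workers with round-robin slicing, and the trainer then averages the returned gradients via `average_grads` from `maro.rl.utils`. The snippet below shows the two ideas in isolation; the function names `round_robin_indexes` and `average_grads_sketch` are made up for this example, and `average_grads_sketch` is a simplified stand-in that averages a flat parameter-name-to-tensor mapping rather than the nested per-policy gradient dictionaries the workers actually return.

    from typing import Dict, List

    import torch


    def round_robin_indexes(batch_size: int, num_workers: int) -> List[range]:
        # Same slicing as _dispatch_batch: worker i gets samples i, i + num_workers, i + 2 * num_workers, ...
        return [range(batch_size)[i::num_workers] for i in range(num_workers)]


    def average_grads_sketch(grad_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Element-wise mean of the gradients returned by each grad worker, keyed by parameter name.
        return {
            name: torch.mean(torch.stack([grads[name] for grads in grad_list]), dim=0)
            for name in grad_list[0]
        }


    if __name__ == "__main__":
        print(round_robin_indexes(batch_size=10, num_workers=3))
        # [range(0, 10, 3), range(1, 10, 3), range(2, 10, 3)]
        fake_grads = [{"fc.weight": torch.full((2, 2), float(i))} for i in range(1, 4)]
        print(average_grads_sketch(fake_grads)["fc.weight"])  # every entry is 2.0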