RL v3 data parallel grad worker (#432)
* Fit the new `trainer_worker` into `grad_worker` and `task_queue`.

* Add batch dispatch.

* Add a `tensor_dict` argument to the task submission interface.

* Move `_remote_learn` to `AbsTrainWorker`.
buptchan authored Dec 14, 2021
1 parent d05bef5 commit c036a45
Showing 5 changed files with 129 additions and 4 deletions.
11 changes: 9 additions & 2 deletions maro/rl/data_parallelism/task_queue.py
@@ -5,6 +5,8 @@
 from multiprocessing import Manager, Process, Queue, managers
 from typing import Dict, List

+import torch
+
 from maro.communication import Proxy, SessionMessage
 from maro.rl.utils import MsgKey, MsgTag
 from maro.utils import DummyLogger, Logger
@@ -34,13 +36,18 @@ def request_workers(self, task_queue_server_name="TASK_QUEUE"):
         return worker_list

     # TODO: rename this method
-    def submit(self, worker_id_list: List, batch_list: List, policy_state: Dict, policy_name: str):
+    def submit(
+        self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str,
+        scope: str = None
+    ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
         """Learn a batch of data on several grad workers."""
         msg_dict = defaultdict(lambda: defaultdict(dict))
         loss_info_by_policy = {policy_name: []}
-        for worker_id, batch in zip(worker_id_list, batch_list):
+        for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
             msg_dict[worker_id][MsgKey.GRAD_TASK][policy_name] = batch
+            msg_dict[worker_id][MsgKey.TENSOR][policy_name] = tensor_dict
             msg_dict[worker_id][MsgKey.POLICY_STATE][policy_name] = policy_state
+            msg_dict[worker_id][MsgKey.GRAD_SCOPE][policy_name] = scope
             # data-parallel by multiple remote gradient workers
             self._proxy.isend(SessionMessage(
                 MsgTag.COMPUTE_GRAD, self._proxy.name, worker_id, body=msg_dict[worker_id]))
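For orientation, here is a minimal sketch of the trainer-side calling sequence this interface implies. It assumes `client` is a `TaskQueueClient` already connected to a running task queue, and that the per-worker splitting of the batch and tensor dict happens upstream (this commit adds `_dispatch_batch`/`_dispatch_tensor_dict` for exactly that); `remote_grad_step` is a hypothetical helper, not part of the commit:

    from typing import Dict, List

    from maro.rl.data_parallelism import TaskQueueClient


    def remote_grad_step(
        client: TaskQueueClient, batch_list: List, tensor_dict_list: List,
        policy_state: Dict, policy_name: str,
    ) -> List[Dict]:
        """Sketch: ship per-worker sub-batches to grad workers and collect gradients."""
        worker_id_list = client.request_workers()  # reserve idle grad workers from the task queue
        loss_info_by_policy = client.submit(
            worker_id_list, batch_list, tensor_dict_list, policy_state, policy_name, scope="all",
        )
        # submit() replies with {policy_name: [grad dict from each worker]}
        return loss_info_by_policy[policy_name]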
2 changes: 2 additions & 0 deletions maro/rl/utils/message_enums.py
@@ -38,8 +38,10 @@ class MsgKey(Enum):
     ROLLOUT_INFO = "rollout_info"
     TRACKER = "tracker"
     GRAD_TASK = "grad_task"
+    GRAD_SCOPE = "grad_scope"
     LOSS_INFO = "loss_info"
     STATE = "state"
+    TENSOR = "tensor"
     POLICY_STATE = "policy_state"
     EXPLORATION_STEP = "exploration_step"
     VERSION = "version"
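Taken together with the `submit` change above, a COMPUTE_GRAD message body now carries four of these keys per policy. A schematic sketch (the placeholder values stand in for a sub-batch, a tensor dict and a policy state; this is not the literal wire format):

    from maro.rl.utils import MsgKey

    msg_body = {
        MsgKey.GRAD_TASK: {"policy_A": None},    # per-policy sub-batch of transitions
        MsgKey.TENSOR: {"policy_A": {}},         # auxiliary tensors needed by the task
        MsgKey.POLICY_STATE: {"policy_A": {}},   # weights to load before computing grads
        MsgKey.GRAD_SCOPE: {"policy_A": "all"},  # which part of the network to differentiate
    }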
20 changes: 20 additions & 0 deletions maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
@@ -229,6 +229,26 @@ def get_batch_grad(

         return grad_dict

+    def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_workers: int) -> List[Dict[str, object]]:
+        raise NotImplementedError
+
+    def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        batch_size = batch.states.shape[0]
+        assert batch_size >= num_workers, \
+            f"Batch size should be greater than or equal to num_workers, but got {batch_size} and {num_workers}."
+        sub_batch_indexes = [range(batch_size)[i::num_workers] for i in range(num_workers)]
+        sub_batches = [MultiTransitionBatch(
+            policy_names=[],
+            states=batch.states[indexes],
+            actions=[action[indexes] for action in batch.actions],
+            rewards=[reward[indexes] for reward in batch.rewards],
+            terminals=batch.terminals[indexes],
+            next_states=batch.next_states[indexes],
+            agent_states=[state[indexes] for state in batch.agent_states],
+            next_agent_states=[state[indexes] for state in batch.next_agent_states]
+        ) for indexes in sub_batch_indexes]
+        return sub_batches
+
     def update_critics(self, next_actions: List[torch.Tensor]) -> None:
         grads = self._get_batch_grad(
             self._batch,
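The strided slice `range(batch_size)[i::num_workers]` in `_dispatch_batch` deals transitions out round-robin, so the sub-batch sizes differ by at most one. A tiny standalone demo of the index pattern:

    batch_size, num_workers = 10, 3
    sub_batch_indexes = [range(batch_size)[i::num_workers] for i in range(num_workers)]
    print([list(indexes) for indexes in sub_batch_indexes])
    # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]  -- sizes differ by at most one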
30 changes: 28 additions & 2 deletions maro/rl_v3/policy_trainer/train_worker.py
@@ -5,6 +5,7 @@

 from maro.communication import Proxy
 from maro.rl.data_parallelism import TaskQueueClient
+from maro.rl.utils import average_grads
 from maro.rl_v3.policy import RLPolicy
 from maro.rl_v3.utils import MultiTransitionBatch, TransitionBatch

@@ -33,11 +34,28 @@ def _get_batch_grad(
         scope: str = "all"
     ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]:
         if self._enable_data_parallelism:
-            assert self._task_queue_client is not None
-            raise NotImplementedError  # TODO
+            gradients = self._remote_learn(batch, tensor_dict, scope)
+            return average_grads(gradients)
         else:
             return self.get_batch_grad(batch, tensor_dict, scope)

+    def _remote_learn(
+        self,
+        batch: MultiTransitionBatch,
+        tensor_dict: Dict[str, object] = None,
+        scope: str = "all"
+    ) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]:
+        assert self._task_queue_client is not None
+        worker_id_list = self._task_queue_client.request_workers()
+        batch_list = self._dispatch_batch(batch, len(worker_id_list))
+        # TODO: implement _dispatch_tensor_dict
+        tensor_dict_list = self._dispatch_tensor_dict(tensor_dict, len(worker_id_list))
+        worker_state = self.get_worker_state_dict()
+        worker_name = self.name
+        loss_info_by_name = self._task_queue_client.submit(
+            worker_id_list, batch_list, tensor_dict_list, worker_state, worker_name, scope)
+        return loss_info_by_name[worker_name]
+
     @abstractmethod
     def get_batch_grad(
         self,
@@ -47,6 +65,14 @@ def get_batch_grad(
     ) -> Dict[str, Dict[int, Dict[str, torch.Tensor]]]:
         raise NotImplementedError

+    @abstractmethod
+    def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_workers: int) -> List[Dict[str, object]]:
+        raise NotImplementedError
+
     def init_data_parallel(self, *args, **kwargs) -> None:
         """
         Initialize a proxy in the policy, for data-parallel training.
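With this change, `_get_batch_grad` averages the per-worker gradients returned by `_remote_learn`. The real `average_grads` lives in `maro.rl.utils`; as a sketch of the semantics assumed here (flattened one level for brevity; the actual structure is also keyed by operation index and scope), element-wise averaging looks like:

    from typing import Dict, List

    import torch


    def average_grads_sketch(grad_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Element-wise mean of per-worker gradient dicts (assumed semantics, not the library code)."""
        return {
            param_name: torch.stack([grads[param_name] for grads in grad_list]).mean(dim=0)
            for param_name in grad_list[0]
        }

Averaging rather than summing keeps the effective learning rate independent of the number of grad workers.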
70 changes: 70 additions & 0 deletions maro/rl_v3/workflows/grad_worker.py
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import time

from maro.communication import Proxy, SessionMessage
from maro.rl.utils import MsgKey, MsgTag
from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module

if __name__ == "__main__":
    # TODO: WORKERID in docker compose script.
    trainer_worker_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "trainer_worker_func_dict")
    worker_id = f"GRAD_WORKER.{from_env('WORKERID')}"
    num_trainer_workers = from_env("NUMTRAINERWORKERS") if from_env("TRAINERTYPE") == "distributed" else 0
    max_cached_policies = from_env("MAXCACHED", required=False, default=10)

    group = from_env("POLICYGROUP", required=False, default="learn")
    policy_dict = {}
    active_policies = []
    if num_trainer_workers == 0:
        # no remote nodes for trainer workers
        num_trainer_workers = len(trainer_worker_func_dict)

    peers = {"trainer": 1, "trainer_workers": num_trainer_workers, "task_queue": 1}
    proxy = Proxy(
        group, "grad_worker", peers, component_name=worker_id,
        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
        max_peer_discovery_retries=50
    )
    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id)

    for msg in proxy.receive():
        if msg.tag == MsgTag.EXIT:
            logger.info("Exiting...")
            proxy.close()
            break
        elif msg.tag == MsgTag.COMPUTE_GRAD:
            t0 = time.time()
            msg_body = {MsgKey.LOSS_INFO: dict(), MsgKey.POLICY_IDS: list()}
            for name, batch in msg.body[MsgKey.GRAD_TASK].items():
                if name not in policy_dict:
                    if len(policy_dict) >= max_cached_policies:
                        # cache is full: evict the least recently used policy
                        policy_to_remove = active_policies.pop()
                        policy_dict.pop(policy_to_remove)
                    # initialize the policy's train worker on first use
                    policy_dict[name] = trainer_worker_func_dict[name](name)
                    active_policies.insert(0, name)
                    logger.info(f"Initialized policy {name}")

                tensor_dict = msg.body[MsgKey.TENSOR][name]
                policy_dict[name].set_worker_state_dict(msg.body[MsgKey.POLICY_STATE][name])
                grad_dict = policy_dict[name].get_batch_grad(
                    batch, tensor_dict, scope=msg.body[MsgKey.GRAD_SCOPE][name])
                msg_body[MsgKey.LOSS_INFO][name] = grad_dict
                msg_body[MsgKey.POLICY_IDS].append(name)
                # move the policy just used to the head of the recency list
                active_policies.remove(name)
                active_policies.insert(0, name)

            logger.debug(f"total policy update time: {time.time() - t0}")
            proxy.reply(msg, tag=MsgTag.COMPUTE_GRAD_DONE, body=msg_body)
            # release this worker back to the task queue
            proxy.isend(SessionMessage(
                MsgTag.RELEASE_WORKER, proxy.name, "TASK_QUEUE", body={MsgKey.WORKER_ID: worker_id}
            ))
        else:
            logger.error(f"Unexpected message tag: {msg.tag}")
            raise TypeError(f"Unexpected message tag: {msg.tag}")
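The `active_policies` list above is a small most-recently-used cache over policy train workers: a fresh hit moves a policy to the head, and the tail is evicted when the cache is full. A self-contained sketch of the same eviction scheme using `collections.OrderedDict` (a hypothetical alternative, not part of this commit):

    from collections import OrderedDict
    from typing import Callable, Dict


    class MRUPolicyCache:
        """Sketch: keep at most `capacity` policies, evicting the least recently used."""

        def __init__(self, factory: Dict[str, Callable], capacity: int = 10):
            self._factory = factory  # policy name -> function that builds its train worker
            self._capacity = capacity
            self._cache: "OrderedDict[str, object]" = OrderedDict()

        def get(self, name: str):
            if name in self._cache:
                self._cache.move_to_end(name)  # mark as most recently used
            else:
                if len(self._cache) >= self._capacity:
                    self._cache.popitem(last=False)  # evict the least recently used
                self._cache[name] = self._factory[name](name)
            return self._cache[name]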
