Complement docstring for task queue and trainer.
buptchan committed Dec 15, 2021
1 parent c036a45 commit 8bdec10
Showing 4 changed files with 52 additions and 3 deletions.
20 changes: 19 additions & 1 deletion maro/rl/data_parallelism/task_queue.py
@@ -40,7 +40,10 @@ def submit(
self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str,
scope: str = None
) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
"""Learn a batch of data on several grad workers."""
"""Learn a batch of data on several grad workers.
For each policy, send a list of batch and state to grad workers, and receive a list of gradients.
The results is actually from train worker's `get_batch_grad()` method, with type:
Dict[str, Dict[int, Dict[str, torch.Tensor]]], which means {scope: {worker_id: {param_name: grad_value}}}"""
msg_dict = defaultdict(lambda: defaultdict(dict))
loss_info_by_policy = {policy_name: []}
for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
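
To make the nested return type above concrete, here is a small, hypothetical sketch of how gradients in the {scope: {worker_id: {param_name: grad_value}}} structure could be averaged across workers. It is illustrative only and not the reduction MARO itself performs.

from typing import Dict

import torch

def average_worker_grads(
    grads_by_scope: Dict[str, Dict[int, Dict[str, torch.Tensor]]], scope: str
) -> Dict[str, torch.Tensor]:
    # Illustrative helper (not part of MARO): average each parameter's gradient
    # over all workers that contributed to the given scope.
    per_worker = grads_by_scope[scope]  # {worker_id: {param_name: grad_value}}
    param_names = next(iter(per_worker.values())).keys()
    return {
        name: torch.stack([grads[name] for grads in per_worker.values()]).mean(dim=0)
        for name in param_names
    }
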
@@ -80,6 +83,21 @@ def task_queue(
proxy_kwargs: dict = {},
logger: Logger = DummyLogger()
):
"""The queue to manage data parallel tasks. Task queue communicates with gradient workers,
maintaing the busy/idle status of workers. Clients send requests to task queue, and task queue
will assign available workers to the requests. Task queue follows the `producer-consumer` model,
consisting of two queues: task_pending, task_assigned. Besides, task queue supports task priority,
adding/deleting workers.
Args:
worker_ids (List[int]): Worker ids to initialize.
num_hosts (int): The number of policy hosts. Will be renamed in RL v3.
num_policies (int): The number of policies.
single_task_limit (float): The limit resource proportion for a single task to assign. Defaults to 0.5
group (str): Group name to initialize proxy. Defaults to DEFAULT_POLICY_GROUP.
proxy_kwargs (dict): Keyword arguments for proxy. Defaults to empty dict.
logger (Logger): Defaults to DummyLogger().
"""
num_workers = len(worker_ids)
if num_hosts == 0:
# for multi-process mode
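
As a rough illustration of the producer-consumer bookkeeping described in the task_queue docstring above, the sketch below keeps a pending queue and an assigned map and hands idle workers to requests as they arrive. The class and method names are hypothetical and do not reflect MARO's actual implementation.

from collections import deque

class SimpleTaskQueue:
    def __init__(self, worker_ids):
        self._idle_workers = deque(worker_ids)   # workers currently free
        self._task_pending = deque()             # tasks waiting for workers
        self._task_assigned = {}                 # task_id -> list of assigned worker ids

    def request_workers(self, task_id, num_workers=1):
        # Queue the request; assign workers immediately if enough are idle.
        self._task_pending.append((task_id, num_workers))
        return self._try_assign()

    def _try_assign(self):
        assigned = {}
        while self._task_pending and len(self._idle_workers) >= self._task_pending[0][1]:
            task_id, n = self._task_pending.popleft()
            workers = [self._idle_workers.popleft() for _ in range(n)]
            self._task_assigned[task_id] = workers
            assigned[task_id] = workers
        return assigned

    def release(self, task_id):
        # Return a finished task's workers to the idle pool and retry pending tasks.
        self._idle_workers.extend(self._task_assigned.pop(task_id))
        self._try_assign()
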
4 changes: 2 additions & 2 deletions maro/rl_v3/policy_trainer/abs_trainer.py
@@ -10,8 +10,8 @@


class AbsTrainer(object, metaclass=ABCMeta):
"""
Policy trainer used to train policies.
"""Policy trainer used to train policies. Trainer maintains several train workers and
controls training logics of them, while train workers take charge of specific policy updating.
"""
def __init__(
self,
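
A minimal, hypothetical sketch of the division of labor described in the AbsTrainer docstring: the trainer orchestrates several train workers, and each worker updates exactly one policy. The names and the policy.learn() call are assumptions for illustration, not MARO's API.

class SketchTrainWorker:
    def __init__(self, policy):
        self._policy = policy  # each worker hosts exactly one policy

    def update(self, batch):
        # Gradient computation and the policy update happen at the worker level.
        self._policy.learn(batch)

class SketchTrainer:
    def __init__(self, policies):
        self._workers = [SketchTrainWorker(p) for p in policies]

    def train_step(self, batches):
        # The trainer only controls the training logic; the workers do the updates.
        for worker, batch in zip(self._workers, batches):
            worker.update(batch)
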
19 changes: 19 additions & 0 deletions maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
@@ -13,6 +13,25 @@


class DiscreteMADDPGWorker(MultiTrainWorker):
"""The discrete variant of MADDPG algorithm.
Args:
name (str): Name of the worker.
device (torch.device): Which device to use.
reward_discount (float): The discount factor of feature reward.
get_q_critic_net_func (Callable[[], MultiQNet): Function to get Q critic net.
shared_critic (bool): Whether to share critic for actors. Defaults to False.
critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 1.0.
soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
(1-soft_update_coef) * target_model. Defaults to 1.0.
update_target_every (int): Number of training rounds between policy target model updates. Defaults to 5.
q_value_loss_func (Callable): The loss function provided by torch.nn or a custom loss class for the
Q-value loss. Defaults to None.
enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
Reference:
Paper: http://papers.nips.cc/paper/by-source-2017-3193
Code: https://github.com/openai/maddpg
"""
def __init__(
self,
name: str,
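
The soft_update_coef formula above is a standard Polyak-style target update. A minimal generic sketch (not MARO's internal code) is shown below; with the default coefficient of 1.0 it reduces to a hard copy every update_target_every rounds.

import torch

@torch.no_grad()
def soft_update(target_net: torch.nn.Module, eval_net: torch.nn.Module, coef: float) -> None:
    # target = coef * eval + (1 - coef) * target, applied parameter-wise.
    for target_param, eval_param in zip(target_net.parameters(), eval_net.parameters()):
        target_param.data.mul_(1.0 - coef).add_(coef * eval_param.data)
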
12 changes: 12 additions & 0 deletions maro/rl_v3/policy_trainer/train_worker.py
@@ -11,6 +11,10 @@


class AbsTrainWorker(object, metaclass=ABCMeta):
"""The basic component for training a policy, which mainly takes charge of gradient computation and policy update.
In trainer, train worker hosts a policy, and trainer hosts several train workers. In gradient workers,
the train worker is an atomic representation of a policy, to perform parallel gradient computing.
"""
def __init__(
self,
name: str,
@@ -45,6 +49,11 @@ def _remote_learn(
tensor_dict: Dict[str, object] = None,
scope: str = "all"
) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]:
"""Learn a batch of experience data from remote gradient workers.
The task queue client will first request available gradient workers from task queue. If all workers are busy,
it will keep waiting until at least 1 worker is available. Then the task queue client submits batch and state
to the assigned workers to compute gradients.
"""
assert self._task_queue_client is not None
worker_id_list = self._task_queue_client.request_workers()
batch_list = self._dispatch_batch(batch, len(worker_id_list))
@@ -67,6 +76,9 @@ def get_batch_grad(

@abstractmethod
def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
"""Split experience data batch to several parts.
For on-policy algorithms, like PG, the batch is splitted into several complete trajectories.
For off-policy algorithms, like DQN, the batch is treated as independent data points and splitted evenly."""
raise NotImplementedError

@abstractmethod
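
To illustrate the two dispatch strategies mentioned in the _dispatch_batch docstring, here is a hedged sketch that uses plain lists as stand-ins for MultiTransitionBatch; the helper names are hypothetical.

from typing import List

def split_evenly(samples: List, num_workers: int) -> List[List]:
    # Off-policy style (e.g. DQN): treat samples as independent points and split them evenly.
    return [samples[i::num_workers] for i in range(num_workers)]

def split_by_trajectory(trajectories: List[List], num_workers: int) -> List[List[List]]:
    # On-policy style (e.g. PG): keep each trajectory intact and distribute whole trajectories round-robin.
    parts = [[] for _ in range(num_workers)]
    for idx, trajectory in enumerate(trajectories):
        parts[idx % num_workers].append(trajectory)
    return parts
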
