diff --git a/maro/rl/data_parallelism/task_queue.py b/maro/rl/data_parallelism/task_queue.py
index 28e6a4393..2efb46597 100644
--- a/maro/rl/data_parallelism/task_queue.py
+++ b/maro/rl/data_parallelism/task_queue.py
@@ -40,7 +40,10 @@ def submit(
         self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict,
         policy_name: str, scope: str = None
     ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
-        """Learn a batch of data on several grad workers."""
+        """Learn a batch of data on several grad workers.
+        For each policy, send a list of batches and the policy state to the grad workers and receive a list of
+        gradients. The results actually come from the train workers' `get_batch_grad()` method, with type
+        Dict[str, Dict[int, Dict[str, torch.Tensor]]], i.e., {scope: {worker_id: {param_name: grad_value}}}."""
         msg_dict = defaultdict(lambda: defaultdict(dict))
         loss_info_by_policy = {policy_name: []}
         for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
@@ -80,6 +83,21 @@ def task_queue(
     proxy_kwargs: dict = {},
     logger: Logger = DummyLogger()
 ):
+    """The queue that manages data parallel tasks. The task queue communicates with the gradient workers,
+    maintaining their busy/idle status. Clients send requests to the task queue, and the task queue assigns
+    available workers to those requests. The task queue follows the producer-consumer model and consists of
+    two queues: task_pending and task_assigned. It also supports task priorities and adding/deleting workers.
+
+    Args:
+        worker_ids (List[int]): Worker ids to initialize.
+        num_hosts (int): The number of policy hosts. Will be renamed in RL v3.
+        num_policies (int): The number of policies.
+        single_task_limit (float): The maximum proportion of resources that may be assigned to a single task.
+            Defaults to 0.5.
+        group (str): Group name used to initialize the proxy. Defaults to DEFAULT_POLICY_GROUP.
+        proxy_kwargs (dict): Keyword arguments for the proxy. Defaults to an empty dict.
+        logger (Logger): Logger instance. Defaults to DummyLogger().
+    """
     num_workers = len(worker_ids)
     if num_hosts == 0:
         # for multi-process mode
diff --git a/maro/rl_v3/policy_trainer/abs_trainer.py b/maro/rl_v3/policy_trainer/abs_trainer.py
index 9cc37a964..06d1fd5aa 100644
--- a/maro/rl_v3/policy_trainer/abs_trainer.py
+++ b/maro/rl_v3/policy_trainer/abs_trainer.py
@@ -10,8 +10,8 @@


 class AbsTrainer(object, metaclass=ABCMeta):
-    """
-    Policy trainer used to train policies.
+    """Policy trainer used to train policies. A trainer maintains several train workers and controls
+    their training logic, while the train workers take charge of the actual policy updates.
     """
     def __init__(
         self,
diff --git a/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
index 089ec5e8d..2c568442e 100644
--- a/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
+++ b/maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py
@@ -13,6 +13,25 @@


 class DiscreteMADDPGWorker(MultiTrainWorker):
+    """The discrete variant of the MADDPG algorithm.
+    Args:
+        name (str): Name of the worker.
+        device (torch.device): Which device to use.
+        reward_discount (float): The discount factor for future rewards.
+        get_q_critic_net_func (Callable[[], MultiQNet]): Function to get the Q critic net.
+        shared_critic (bool): Whether to share the critic among actors. Defaults to False.
+        critic_loss_coef (float): Coefficient of the critic loss in the total loss. Defaults to 1.0.
+        soft_update_coef (float): Soft update coefficient, i.e., target_model = soft_update_coef * eval_model +
+            (1 - soft_update_coef) * target_model. Defaults to 1.0.
+        update_target_every (int): Number of training rounds between policy target model updates. Defaults to 5.
+        q_value_loss_func (Callable): The loss function, provided by torch.nn or a custom loss class, for the
+            Q-value loss. Defaults to None.
+        enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
+
+    Reference:
+        Paper: http://papers.nips.cc/paper/by-source-2017-3193
+        Code: https://github.com/openai/maddpg
+    """
     def __init__(
         self,
         name: str,
diff --git a/maro/rl_v3/policy_trainer/train_worker.py b/maro/rl_v3/policy_trainer/train_worker.py
index 268c432ad..09d2ef47c 100644
--- a/maro/rl_v3/policy_trainer/train_worker.py
+++ b/maro/rl_v3/policy_trainer/train_worker.py
@@ -11,6 +11,10 @@


 class AbsTrainWorker(object, metaclass=ABCMeta):
+    """The basic component for training a policy; it mainly takes charge of gradient computation and policy updates.
+    A train worker hosts a single policy and a trainer hosts several train workers. From the gradient workers'
+    perspective, a train worker is the atomic representation of a policy, used for parallel gradient computation.
+    """
     def __init__(
         self,
         name: str,
@@ -45,6 +49,11 @@ def _remote_learn(
         tensor_dict: Dict[str, object] = None,
         scope: str = "all"
     ) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]:
+        """Learn a batch of experience data from remote gradient workers.
+        The task queue client first requests available gradient workers from the task queue. If all workers are
+        busy, it keeps waiting until at least one worker is available. The task queue client then submits the
+        batch and policy state to the assigned workers to compute gradients.
+        """
         assert self._task_queue_client is not None
         worker_id_list = self._task_queue_client.request_workers()
         batch_list = self._dispatch_batch(batch, len(worker_id_list))
@@ -67,6 +76,9 @@ def get_batch_grad(

     @abstractmethod
     def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        """Split the experience data batch into several parts.
+        For on-policy algorithms such as PG, the batch is split into several complete trajectories.
+        For off-policy algorithms such as DQN, the batch is treated as independent data points and split evenly."""
         raise NotImplementedError

     @abstractmethod
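
The producer-consumer scheme described in the task_queue() docstring, a task_pending queue drained into a task_assigned map as workers become free, can be pictured with Python's standard queue module. The snippet below is illustrative only; the variable names mirror the docstring but are not the actual MARO internals.

import queue

# Illustrative only: a priority queue of pending tasks plus a map of assigned tasks,
# mirroring the task_pending / task_assigned split described in the docstring.
task_pending = queue.PriorityQueue()
task_assigned = {}
idle_workers = [0, 1, 2]

task_pending.put((0, "compute_grad_policy_A"))  # lower number = higher priority
task_pending.put((1, "compute_grad_policy_B"))

while idle_workers and not task_pending.empty():
    _priority, task = task_pending.get()
    worker_id = idle_workers.pop(0)
    task_assigned[worker_id] = task  # this worker is now busy

print(task_assigned)  # {0: 'compute_grad_policy_A', 1: 'compute_grad_policy_B'}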
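
The nested result type documented for submit() and get_batch_grad(), {scope: {worker_id: {param_name: grad_value}}}, is easiest to read with a small example. The helper below is a hypothetical sketch of how such a result could be averaged across workers; it is not part of this PR.

from typing import Dict

import torch


def average_worker_grads(result: Dict[str, Dict[int, Dict[str, torch.Tensor]]], scope: str = "all") -> Dict[str, torch.Tensor]:
    # Hypothetical helper: average the per-parameter gradients returned by several grad workers.
    per_worker = result[scope]  # {worker_id: {param_name: grad_value}}
    averaged: Dict[str, torch.Tensor] = {}
    for grads in per_worker.values():
        for name, grad in grads.items():
            averaged[name] = averaged.get(name, torch.zeros_like(grad)) + grad / len(per_worker)
    return averaged


# Two workers, one parameter: the averaged gradient is a tensor of 2.0s.
dummy = {"all": {0: {"fc.weight": torch.ones(2, 2)}, 1: {"fc.weight": 3 * torch.ones(2, 2)}}}
print(average_worker_grads(dummy)["fc.weight"])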
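
The soft_update_coef entry in the DiscreteMADDPGWorker docstring describes the usual Polyak-style target update. The standalone snippet below illustrates exactly that formula on two arbitrary torch modules (the layer shapes are made up for the example); with the default soft_update_coef of 1.0, the update degenerates to a hard copy of the eval model.

import torch
import torch.nn as nn

# Standalone illustration of
# target_model = soft_update_coef * eval_model + (1 - soft_update_coef) * target_model.
eval_model = nn.Linear(4, 2)
target_model = nn.Linear(4, 2)
soft_update_coef = 0.1

with torch.no_grad():
    for target_param, eval_param in zip(target_model.parameters(), eval_model.parameters()):
        target_param.mul_(1.0 - soft_update_coef).add_(soft_update_coef * eval_param)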
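
The even split that _dispatch_batch() documents for the off-policy case can be sketched with plain NumPy arrays instead of MultiTransitionBatch; the function below is a rough illustration under that simplification, not the MARO implementation.

from typing import Dict, List

import numpy as np


def split_evenly(states: np.ndarray, actions: np.ndarray, num_workers: int) -> List[Dict[str, np.ndarray]]:
    # Off-policy style dispatch: treat samples as independent points and split them as evenly as possible.
    sub_batches = []
    for idx in np.array_split(np.arange(len(states)), num_workers):
        sub_batches.append({"states": states[idx], "actions": actions[idx]})
    return sub_batches


batches = split_evenly(np.random.rand(10, 3), np.random.randint(0, 4, size=10), num_workers=3)
print([len(b["states"]) for b in batches])  # [4, 3, 3]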