fixed merge conflicts
ysqyang committed Dec 27, 2021
2 parents 9ebabaa + 8e4ad49 commit f293980
Showing 4 changed files with 87 additions and 14 deletions.
27 changes: 15 additions & 12 deletions maro/rl_v3/policy_trainer/abs_train_ops.py
@@ -17,19 +17,16 @@ class AbsTrainOps(object, metaclass=ABCMeta):
    """
    def __init__(
        self,
        name: str,
        device: torch.device,
        enable_data_parallelism: bool = False
    ) -> None:
        super(AbsTrainOps, self).__init__()
        self._name = None
        self._name = name
        self._enable_data_parallelism = enable_data_parallelism
        self._task_queue_client: Optional[TaskQueueClient] = None
        self._device = device

    @abstractmethod
    def get_models(self):
        raise NotImplementedError

    @property
    def name(self) -> str:
        return self._name
@@ -129,15 +126,21 @@ def set_ops_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
class SingleTrainOps(AbsTrainOps, metaclass=ABCMeta):
    def __init__(
        self,
        policy: RLPolicy,
        name: str,
        device: torch.device,
        enable_data_parallelism: bool = False
    ) -> None:
        super(SingleTrainOps, self).__init__(device, enable_data_parallelism)
        self._policy = policy
        self._name = self._policy.name
        self._policy.to_device(self._device)
        super(SingleTrainOps, self).__init__(name, device, enable_data_parallelism)
        self._batch: Optional[TransitionBatch] = None
        self._policy: Optional[RLPolicy] = None

    def register_policy(self, policy: RLPolicy) -> None:
        policy.to_device(self._device)
        self._register_policy_impl(policy)

    @abstractmethod
    def _register_policy_impl(self, policy: RLPolicy) -> None:
        raise NotImplementedError

    def set_batch(self, batch: TransitionBatch) -> None:
        self._batch = batch
@@ -152,11 +155,11 @@ def set_policy_state(self, policy_state: object) -> None:
class MultiTrainOps(AbsTrainOps, metaclass=ABCMeta):
    def __init__(
        self,
        policies: List[RLPolicy],
        name: str,
        device: torch.device,
        enable_data_parallelism: bool = False
    ) -> None:
        super(MultiTrainOps, self).__init__(device, enable_data_parallelism)
        super(MultiTrainOps, self).__init__(name, device, enable_data_parallelism)
        self._batch: Optional[MultiTransitionBatch] = None
        self._policies: Dict[int, RLPolicy] = {}
        self._indexes: List[int] = []
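After this change, an ops object receives its name and device at construction time and has its policy attached afterwards through register_policy, which moves the policy onto the ops' device before delegating to the subclass hook _register_policy_impl. Below is a minimal sketch of that flow, assuming a hypothetical concrete subclass; any other abstract members of SingleTrainOps are omitted here and would have to be implemented in a real ops class.

import torch

from maro.rl_v3.policy import RLPolicy
from maro.rl_v3.policy_trainer.abs_train_ops import SingleTrainOps


class IllustrativeOps(SingleTrainOps):
    # Hypothetical subclass for illustration only; real ops classes also
    # implement the remaining abstract members of SingleTrainOps.
    def _register_policy_impl(self, policy: RLPolicy) -> None:
        # By the time this hook runs, register_policy() has already moved
        # `policy` onto self._device.
        self._policy = policy


ops = IllustrativeOps(name="pg.0", device=torch.device("cpu"))
ops.register_policy(some_policy)  # some_policy: an RLPolicy instance built elsewhere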
2 changes: 1 addition & 1 deletion maro/rl_v3/policy_trainer/ac.py
@@ -6,7 +6,7 @@

from maro.rl.utils import discount_cumsum
from maro.rl_v3.model import VNet
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.policy import DiscretePolicyGradient, RLPolicy
from maro.rl_v3.replay_memory import FIFOReplayMemory
from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor

2 changes: 1 addition & 1 deletion maro/rl_v3/policy_trainer/dqn.py
@@ -2,7 +2,7 @@

import torch

from maro.rl_v3.policy import ValueBasedPolicy
from maro.rl_v3.policy import RLPolicy, ValueBasedPolicy
from maro.rl_v3.replay_memory import RandomReplayMemory
from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor
from maro.utils import clone
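The ac.py and dqn.py changes above only widen the imports to include RLPolicy. A plausible reason is that the ops classes in those modules now override the new _register_policy_impl(policy: RLPolicy) hook and narrow the argument to the concrete policy type they train. A hedged sketch of that pattern follows; the class name is illustrative and is not the actual class defined in dqn.py.

from maro.rl_v3.policy import RLPolicy, ValueBasedPolicy


class _IllustrativeDQNOps:
    # Stand-in for the DQN ops class; only the registration hook is sketched.
    def _register_policy_impl(self, policy: RLPolicy) -> None:
        assert isinstance(policy, ValueBasedPolicy)  # DQN expects a value-based policy
        self._policy = policy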
70 changes: 70 additions & 0 deletions maro/rl_v3/workflows/grad_worker.py
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import time

from maro.communication import Proxy, SessionMessage
from maro.rl.utils import MsgKey, MsgTag
from maro.rl.workflows.helpers import from_env, get_logger, get_scenario_module

if __name__ == "__main__":
    # TODO: WORKERID in docker compose script.
    trainer_worker_func_dict = getattr(get_scenario_module(from_env("SCENARIODIR")), "trainer_worker_func_dict")
    worker_id = f"GRAD_WORKER.{from_env('WORKERID')}"
    num_trainer_workers = from_env("NUMTRAINERWORKERS") if from_env("TRAINERTYPE") == "distributed" else 0
    max_cached_policies = from_env("MAXCACHED", required=False, default=10)

    group = from_env("POLICYGROUP", required=False, default="learn")
    policy_dict = {}
    active_policies = []
    if num_trainer_workers == 0:
        # no remote nodes for trainer workers
        num_trainer_workers = len(trainer_worker_func_dict)

    peers = {"trainer": 1, "trainer_workers": num_trainer_workers, "task_queue": 1}
    proxy = Proxy(
        group, "grad_worker", peers, component_name=worker_id,
        redis_address=(from_env("REDISHOST"), from_env("REDISPORT")),
        max_peer_discovery_retries=50
    )
    logger = get_logger(from_env("LOGDIR", required=False, default=os.getcwd()), from_env("JOB"), worker_id)

    for msg in proxy.receive():
        if msg.tag == MsgTag.EXIT:
            logger.info("Exiting...")
            proxy.close()
            break
        elif msg.tag == MsgTag.COMPUTE_GRAD:
            t0 = time.time()
            msg_body = {MsgKey.LOSS_INFO: dict(), MsgKey.POLICY_IDS: list()}
            for name, batch in msg.body[MsgKey.GRAD_TASK].items():
                if name not in policy_dict:
                    if len(policy_dict) > max_cached_policies:
                        # remove the oldest one when size exceeds.
                        policy_to_remove = active_policies.pop()
                        policy_dict.pop(policy_to_remove)
                    # Initialize
                    policy_dict[name] = trainer_worker_func_dict[name](name)
                    active_policies.insert(0, name)
                    logger.info(f"Initialized policies {name}")

                tensor_dict = msg.body[MsgKey.TENSOR][name]
                policy_dict[name].set_ops_state_dict(msg.body[MsgKey.POLICY_STATE][name])
                grad_dict = policy_dict[name].get_batch_grad(
                    batch, tensor_dict, scope=msg.body[MsgKey.GRAD_SCOPE][name])
                msg_body[MsgKey.LOSS_INFO][name] = grad_dict
                msg_body[MsgKey.POLICY_IDS].append(name)
                # put the latest one to queue head
                active_policies.remove(name)
                active_policies.insert(0, name)

            logger.debug(f"total policy update time: {time.time() - t0}")
            proxy.reply(msg, tag=MsgTag.COMPUTE_GRAD_DONE, body=msg_body)
            # release worker at task queue
            proxy.isend(SessionMessage(
                MsgTag.RELEASE_WORKER, proxy.name, "TASK_QUEUE", body={MsgKey.WORKER_ID: worker_id}
            ))
        else:
            logger.info(f"Wrong message tag: {msg.tag}")
            raise TypeError

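The grad worker above keeps at most max_cached_policies policy ops instances in memory, using active_policies as a recency list whose head is the most recently used entry and whose tail is evicted first. Below is a standalone sketch of that caching behavior, with illustrative names and no MARO dependencies.

from typing import Dict, List

MAX_CACHED = 2                      # stands in for max_cached_policies

cache: Dict[str, str] = {}          # stands in for policy_dict
recency: List[str] = []             # stands in for active_policies; head = most recent


def get_ops(name: str) -> str:
    if name not in cache:
        if len(cache) > MAX_CACHED:
            oldest = recency.pop()  # evict the least recently used entry
            cache.pop(oldest)
        cache[name] = f"ops-for-{name}"  # stands in for trainer_worker_func_dict[name](name)
        recency.insert(0, name)
    else:
        recency.remove(name)
        recency.insert(0, name)     # move the latest entry to the head
    return cache[name]


for policy_name in ["a", "b", "c", "a", "d"]:
    get_ops(policy_name)
print(recency)  # ['d', 'a', 'c']; "b" was evicted once the cap was exceeded

As in the worker, the size check uses a strict greater-than comparison, so the cache can momentarily hold one entry more than the threshold before eviction kicks in.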