
Calculating advantages / returns (#454)
* V1.0

* Complete DDPG
lihuoran authored Jan 11, 2022 · 1 parent 0bdc230 · commit 1610a3f
Showing 6 changed files with 99 additions and 46 deletions.
39 changes: 22 additions & 17 deletions maro/rl_v3/training/algorithms/ac.py
@@ -116,17 +116,12 @@ def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_sub_batches:
raise NotImplementedError

def _get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
states = ndarray_to_tensor(batch.states, self._device) # s

self._policy.train()
self._v_critic_net.train()

states = ndarray_to_tensor(batch.states, self._device) # s
state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])

returns = ndarray_to_tensor(discount_cumsum(rewards, self._reward_discount)[:-1], self._device)
returns = ndarray_to_tensor(batch.returns, self._device)
critic_loss = self._critic_loss_func(state_values, returns)

return self._v_critic_net.get_gradients(critic_loss * self._critic_loss_coef)
@@ -136,21 +131,14 @@ def _get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:

states = ndarray_to_tensor(batch.states, self._device) # s
actions = ndarray_to_tensor(batch.actions, self._device).long() # a
advantages = ndarray_to_tensor(self._batch.advantages, self._device)

if self._clip_ratio is not None:
self._policy.eval()
logps_old = self._policy.get_state_action_logps(states, actions)
else:
logps_old = None

state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])

deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
advantages = ndarray_to_tensor(discount_cumsum(deltas, self._reward_discount * self._lam), self._device)

action_probs = self._policy.get_action_probs(states)
logps = torch.log(action_probs.gather(1, actions).squeeze())
logps = torch.clamp(logps, min=self._min_logp, max=.0)
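
The rest of the actor update is collapsed in this view. Given the variables prepared above (``logps``, ``logps_old``, ``advantages``, ``self._clip_ratio``), here is a hedged sketch of the standard clipped / vanilla policy-gradient surrogate such an actor update typically computes; it is a sketch under those assumptions, not necessarily the exact hidden code.

```python
from typing import Optional

import torch


def actor_loss(
    logps: torch.Tensor,                        # log pi(a|s) under the current policy
    logps_old: Optional[torch.Tensor],          # log-probs recorded before the update (None if clipping is off)
    advantages: torch.Tensor,
    clip_ratio: Optional[float] = None,
) -> torch.Tensor:
    if clip_ratio is not None:
        ratio = torch.exp(logps - logps_old)                              # pi_new / pi_old
        clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)  # clip the ratio
        return -torch.min(ratio * advantages, clipped * advantages).mean()
    return -(logps * advantages).mean()                                   # vanilla policy gradient
```
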
@@ -187,6 +175,23 @@ def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
if scope in ("all", "critic"):
self._v_critic_net.set_net_state(ops_state_dict["critic_state"])

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch

# Preprocess returns
self._batch.calc_returns(self._reward_discount)

# Preprocess advantages
states = ndarray_to_tensor(batch.states, self._device) # s
state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])
deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
advantages = discount_cumsum(deltas, self._reward_discount * self._lam)
self._batch.advantages = advantages
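
The advantage preprocessing above is the usual GAE(λ) recipe: bootstrap the value trajectory with its last entry, form the TD residuals r + γ·v(s') − v(s), and take a discounted cumulative sum with factor γλ. Below is a self-contained sketch of that computation, assuming ``discount_cumsum`` is the standard reverse discounted cumulative sum (as in the Spinning Up utilities); it mirrors the logic in ``set_batch`` but is not the project's own implementation.

```python
import numpy as np


def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    """Reverse discounted cumulative sum: out[t] = sum_k discount**k * x[t + k]."""
    out = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out


def gae_advantages(rewards: np.ndarray, values: np.ndarray, gamma: float, lam: float) -> np.ndarray:
    """GAE(lambda) advantages; `values` holds v(s_t) for each state in the trajectory."""
    values = np.concatenate([values, values[-1:]])             # bootstrap with the last state value
    rewards = np.concatenate([rewards, values[-1:]])           # pad so the shapes line up
    deltas = rewards[:-1] + gamma * values[1:] - values[:-1]   # r + gamma * v(s') - v(s)
    return discount_cumsum(deltas, gamma * lam)


# Example: a three-step trajectory with constant rewards and values.
print(gae_advantages(np.array([1.0, 1.0, 1.0]), np.array([0.5, 0.5, 0.5]), gamma=0.9, lam=0.95))
```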


class DiscreteActorCritic(SingleTrainer):
"""Actor Critic algorithm with separate policy and value models.
@@ -234,8 +239,8 @@ def record(self, exp_element: ExpElement) -> None:
def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DiscreteActorCriticOps(**self._ops_params)

def _get_batch(self, agent_name: str, batch_size: int = None) -> TransitionBatch:
return self._replay_memory_dict[agent_name].sample(batch_size if batch_size is not None else self._batch_size)
def _get_batch(self, agent_name: str) -> TransitionBatch:
return self._replay_memory_dict[agent_name].sample(-1) # Use all entries in the replay memory

async def train_step(self):
for agent_name in self._replay_memory_dict:
88 changes: 61 additions & 27 deletions maro/rl_v3/training/algorithms/ddpg.py
@@ -4,10 +4,12 @@
# TODO: DDPG has not been tested in a real test case

from dataclasses import dataclass
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional

import numpy as np
import torch

from maro.rl_v3.learning import ExpElement
from maro.rl_v3.model import QNet
from maro.rl_v3.policy import ContinuousRLPolicy
from maro.rl_v3.training import AbsTrainOps, RandomReplayMemory, SingleTrainer, TrainerParams
@@ -17,6 +19,21 @@

@dataclass
class DDPGParams(TrainerParams):
"""
get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net.
reward_discount (float): Reward decay as defined in standard RL terminology.
num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1.
update_target_every (int): Number of training rounds between policy target model updates.
q_value_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for
the Q-value loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse".
soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
(1-soft_update_coef) * target_model. Defaults to 1.0.
critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1.
random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True,
overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with
wrap-around. Defaults to False.
"""
get_q_critic_net_func: Callable[[], QNet] = None
reward_discount: float = 0.9
num_epochs: int = 1
update_target_every: int = 5
@@ -25,6 +42,20 @@ class DDPGParams(TrainerParams):
critic_loss_coef: float = 0.1
random_overwrite: bool = False

def __post_init__(self) -> None:
assert self.get_q_critic_net_func is not None

def extract_ops_params(self) -> Dict[str, object]:
return {
"device": self.device,
"enable_data_parallelism": self.enable_data_parallelism,
"get_q_critic_net_func": self.get_q_critic_net_func,
"reward_discount": self.reward_discount,
"q_value_loss_cls": self.q_value_loss_cls,
"soft_update_coef": self.soft_update_coef,
"critic_loss_coef": self.critic_loss_coef,
}
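
A hedged usage sketch of the new params class: the QNet factory here (``MyQCriticNet`` and its constructor arguments) is a hypothetical placeholder, and the example assumes the remaining ``TrainerParams`` fields carry defaults; it only illustrates how ``extract_ops_params`` feeds ``DDPGOps``.

```python
def get_q_critic_net() -> QNet:
    # MyQCriticNet is a stand-in for any QNet subclass defined in your project.
    return MyQCriticNet(state_dim=10, action_dim=2)


params = DDPGParams(
    get_q_critic_net_func=get_q_critic_net,  # required; __post_init__ asserts it is provided
    reward_discount=0.99,
    num_epochs=10,
    update_target_every=5,
    soft_update_coef=0.05,
)
ops_kwargs = params.extract_ops_params()  # later combined with get_policy_func and passed to DDPGOps(**...)
```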


class DDPGOps(AbsTrainOps):
def __init__(
@@ -159,65 +190,68 @@ def soft_update_target(self) -> None:
self._target_policy.soft_update(self._policy, self._soft_update_coef)
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch
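
``soft_update_target`` delegates to ``soft_update`` on the policy and the critic net. Going by the coefficient convention in the params docstring (target = coef * eval + (1 - coef) * target), here is a minimal sketch of a Polyak-style soft update over plain torch modules; MARO's own ``soft_update`` may differ in detail.

```python
import torch


@torch.no_grad()
def polyak_update(target_net: torch.nn.Module, eval_net: torch.nn.Module, coef: float) -> None:
    """target = coef * eval + (1 - coef) * target, applied parameter-wise."""
    for target_param, eval_param in zip(target_net.parameters(), eval_net.parameters()):
        target_param.mul_(1.0 - coef).add_(coef * eval_param)
```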


class DDPG(SingleTrainer):
"""The Deep Deterministic Policy Gradient (DDPG) algorithm.
References:
https://arxiv.org/pdf/1509.02971.pdf
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
Args:
name (str): Unique identifier for the trainer.
policy_creator (Dict[str, Callable[[str], DiscretePolicyGradient]]): Dict of functions used to
create policies.
device (str): Identifier for the torch device. The policy will be moved to the specified device. If it is
None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None.
enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
replay_memory_capacity (int): Capacity of the replay memory. Defaults to 10000.
random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True,
overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with
wrap-around. Defaults to False.
num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1.
update_target_every (int): Number of training rounds between policy target model updates.
reward_discount (float): Reward decay as defined in standard RL terminology.
q_value_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for
the Q-value loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse".
soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
(1-soft_update_coef) * target_model. Defaults to 1.0.
critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1.
"""

def __init__(self, name: str, params: DDPGParams) -> None:
super(DDPG, self).__init__(name, params)
self._params = params
self._policy_version = self._target_policy_version = 0
self._q_net_version = self._target_q_net_version = 0
self._ops_name = f"{self._name}.ops"

self._replay_memory: Optional[RandomReplayMemory] = None

def build(self) -> None:
self._ops_params = {
"get_policy_func": self._get_policy_func,
**self._params.extract_ops_params(),
}

self._ops = self.get_ops(self._ops_name)

self._replay_memory = RandomReplayMemory(
capacity=self._params.replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite
)

def record(self, exp_element: ExpElement) -> None:
for agent_name in exp_element.agent_names:
transition_batch = TransitionBatch(
states=np.expand_dims(exp_element.agent_state_dict[agent_name], axis=0),
actions=np.expand_dims(exp_element.action_dict[agent_name], axis=0),
rewards=np.array([exp_element.reward_dict[agent_name]]),
terminals=np.array([exp_element.terminal_dict[agent_name]]),
next_states=np.expand_dims(
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
axis=0,
),
)
self._replay_memory.put(transition_batch)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DDPGOps(**self._ops_params)

def _get_batch(self, batch_size: int = None) -> TransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)

async def train_step(self) -> None:
for _ in range(self._params.num_epochs):
await self._ops.set_batch(self._get_batch())
await self._ops.update()
self._policy_version += 1
if self._policy_version - self._target_policy_version == self._params.update_target_every:
await self._ops.soft_update_target()
self._target_policy_version = self._policy_version

self._policy_version += 1
if self._policy_version - self._target_policy_version == self._params.update_target_every:
await self._ops.soft_update_target()
self._target_policy_version = self._policy_version
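
``DDPGOps.update`` is collapsed in this view. For orientation, here is a generic sketch of the critic target DDPG computes, following the referenced paper and Spinning Up implementation rather than this exact code; the ``get_actions`` / ``q_values`` method names are assumptions, not MARO's API.

```python
import torch


def ddpg_critic_target(
    rewards: torch.Tensor,      # shape (N,)
    next_states: torch.Tensor,  # shape (N, state_dim)
    terminals: torch.Tensor,    # shape (N,), 0.0 or 1.0
    gamma: float,
    target_policy,              # deterministic target policy mu'
    target_q_net,               # target critic Q'
) -> torch.Tensor:
    """y = r + gamma * (1 - d) * Q'(s', mu'(s')), the regression target for the critic."""
    with torch.no_grad():
        next_actions = target_policy.get_actions(next_states)      # a' = mu'(s')
        next_q = target_q_net.q_values(next_states, next_actions)  # Q'(s', a')
        return rewards + gamma * (1.0 - terminals) * next_q
```
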
4 changes: 4 additions & 0 deletions maro/rl_v3/training/algorithms/dqn.py
@@ -133,6 +133,10 @@ def update(self) -> None:
def soft_update_target(self) -> None:
self._target_policy.soft_update(self._policy, self._soft_update_coef)

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch


class DQN(SingleTrainer):
"""The Deep-Q-Networks algorithm.
4 changes: 4 additions & 0 deletions maro/rl_v3/training/algorithms/maddpg.py
@@ -258,6 +258,10 @@ def soft_update_target(self) -> None:
if not self._shared_critic:
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)

def set_batch(self, batch: MultiTransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch


class DiscreteMADDPG(MultiTrainer):
def __init__(self, name: str, params: DiscreteMADDPGParams) -> None:
4 changes: 2 additions & 2 deletions maro/rl_v3/training/train_ops.py
@@ -138,9 +138,9 @@ def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
"""Set ops's state."""
raise NotImplementedError

@abstractmethod
def set_batch(self, batch: AbsTransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch
raise NotImplementedError

def get_policy_state(self) -> object:
return self._policy.name, self._policy.get_state()
6 changes: 6 additions & 0 deletions maro/rl_v3/utils/transition_batch.py
@@ -3,6 +3,7 @@

import numpy as np

from . import discount_cumsum
from .objects import SHAPE_CHECK_FLAG


@@ -13,6 +14,8 @@ class TransitionBatch:
rewards: np.ndarray # 1D
next_states: np.ndarray # 2D
terminals: np.ndarray # 1D
returns: np.ndarray = None # 1D
advantages: np.ndarray = None # 1D

def __post_init__(self) -> None:
if SHAPE_CHECK_FLAG:
@@ -22,6 +25,9 @@ def __post_init__(self) -> None:
assert self.next_states.shape == self.states.shape
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]

def calc_returns(self, discount_factor: float) -> None:
self.returns = discount_cumsum(self.rewards, discount_factor)
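
``calc_returns`` stores the discounted cumulative sum of the rewards so that downstream ops (e.g. the actor-critic critic loss above) can read ``batch.returns`` instead of recomputing it. A small worked example, assuming ``discount_cumsum`` follows the standard definition returns[t] = sum_k gamma**k * r[t + k]:

```python
import numpy as np

batch = TransitionBatch(
    states=np.zeros((3, 4), dtype=np.float32),       # 3 steps, 4-dim states
    actions=np.zeros((3, 1), dtype=np.float32),
    rewards=np.array([1.0, 1.0, 1.0]),
    next_states=np.zeros((3, 4), dtype=np.float32),
    terminals=np.array([False, False, True]),
)
batch.calc_returns(discount_factor=0.9)
print(batch.returns)  # [2.71, 1.9, 1.0]
```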


@dataclass
class MultiTransitionBatch:

