Calculating advantages / returns #454

Merged 2 commits on Jan 11, 2022
39 changes: 22 additions & 17 deletions maro/rl_v3/training/algorithms/ac.py
@@ -116,17 +116,12 @@ def _dispatch_tensor_dict(self, tensor_dict: Dict[str, object], num_sub_batches:
raise NotImplementedError

def _get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
states = ndarray_to_tensor(batch.states, self._device) # s

self._policy.train()
self._v_critic_net.train()

states = ndarray_to_tensor(batch.states, self._device) # s
state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])

returns = ndarray_to_tensor(discount_cumsum(rewards, self._reward_discount)[:-1], self._device)
returns = ndarray_to_tensor(batch.returns, self._device)
critic_loss = self._critic_loss_func(state_values, returns)

return self._v_critic_net.get_gradients(critic_loss * self._critic_loss_coef)
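
With this change the critic loss regresses V(s) directly onto the returns precomputed in ``set_batch`` (added further down) instead of recomputing them here. A minimal illustration of that regression, assuming an MSE-style ``_critic_loss_func``; the tensor values are made up:

```python
# Sketch only: V(s_t) regressed onto precomputed discounted returns.
import torch

state_values = torch.tensor([0.5, 0.4, 0.8], requires_grad=True)  # stand-in for critic outputs V(s_t)
returns = torch.tensor([3.0, 4.0, 4.0])                            # stand-in for batch.returns
critic_loss = torch.nn.functional.mse_loss(state_values, returns)
critic_loss.backward()  # gradients w.r.t. state_values in this toy example
```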
@@ -136,21 +131,14 @@ def _get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:

states = ndarray_to_tensor(batch.states, self._device) # s
actions = ndarray_to_tensor(batch.actions, self._device).long() # a
advantages = ndarray_to_tensor(self._batch.advantages, self._device)

if self._clip_ratio is not None:
self._policy.eval()
logps_old = self._policy.get_state_action_logps(states, actions)
else:
logps_old = None

state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])

deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
advantages = ndarray_to_tensor(discount_cumsum(deltas, self._reward_discount * self._lam), self._device)

action_probs = self._policy.get_action_probs(states)
logps = torch.log(action_probs.gather(1, actions).squeeze())
logps = torch.clamp(logps, min=self._min_logp, max=.0)
@@ -187,6 +175,23 @@ def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
if scope in ("all", "critic"):
self._v_critic_net.set_net_state(ops_state_dict["critic_state"])

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch

# Preprocess returns
self._batch.calc_returns(self._reward_discount)

# Preprocess advantages
states = ndarray_to_tensor(batch.states, self._device) # s
state_values = self._v_critic_net.v_values(states)
values = state_values.detach().numpy()
values = np.concatenate([values, values[-1:]])
rewards = np.concatenate([batch.rewards, values[-1:]])
deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
advantages = discount_cumsum(deltas, self._reward_discount * self._lam)
self._batch.advantages = advantages
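
The new ``set_batch`` precomputes both the discounted returns (critic targets) and GAE-style advantages once per batch, so the gradient functions above can reuse them. A standalone NumPy sketch of the same computation, assuming ``discount_cumsum(x, d)[t] == sum_k d**k * x[t+k]``; the rewards, values, and coefficients are invented for illustration:

```python
import numpy as np

def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    # Discounted cumulative sum from each index to the end of the array.
    out = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out

gamma, lam = 0.9, 0.95
rewards = np.array([1.0, 0.0, 2.0])        # batch.rewards
values = np.array([0.5, 0.4, 0.8])         # V(s_t) from the critic

returns = discount_cumsum(rewards, gamma)  # critic regression targets

# GAE: bootstrap with the last value, then discount the TD errors by gamma * lam.
values_ext = np.concatenate([values, values[-1:]])
rewards_ext = np.concatenate([rewards, values[-1:]])
deltas = rewards_ext[:-1] + gamma * values_ext[1:] - values_ext[:-1]  # r + gamma * v(s') - v(s)
advantages = discount_cumsum(deltas, gamma * lam)
print(returns, advantages)
```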


class DiscreteActorCritic(SingleTrainer):
"""Actor Critic algorithm with separate policy and value models.
Expand Down Expand Up @@ -234,8 +239,8 @@ def record(self, exp_element: ExpElement) -> None:
def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DiscreteActorCriticOps(**self._ops_params)

def _get_batch(self, agent_name: str, batch_size: int = None) -> TransitionBatch:
return self._replay_memory_dict[agent_name].sample(batch_size if batch_size is not None else self._batch_size)
def _get_batch(self, agent_name: str) -> TransitionBatch:
return self._replay_memory_dict[agent_name].sample(-1) # Use all entries in the replay memory

async def train_step(self):
for agent_name in self._replay_memory_dict:
88 changes: 61 additions & 27 deletions maro/rl_v3/training/algorithms/ddpg.py
@@ -4,10 +4,12 @@
# TODO: DDPG has not been tested in a real test case

from dataclasses import dataclass
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional

import numpy as np
import torch

from maro.rl_v3.learning import ExpElement
from maro.rl_v3.model import QNet
from maro.rl_v3.policy import ContinuousRLPolicy
from maro.rl_v3.training import AbsTrainOps, RandomReplayMemory, SingleTrainer, TrainerParams
@@ -17,6 +19,21 @@

@dataclass
class DDPGParams(TrainerParams):
"""
get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net.
reward_discount (float): Reward decay as defined in standard RL terminology.
num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1.
update_target_every (int): Number of training rounds between policy target model updates.
q_value_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for
the Q-value loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse".
soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
(1-soft_update_coef) * target_model. Defaults to 1.0.
critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1.
random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True,
overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with
wrap-around. Defaults to False.
"""
get_q_critic_net_func: Callable[[], QNet] = None
reward_discount: float = 0.9
num_epochs: int = 1
update_target_every: int = 5
@@ -25,6 +42,20 @@ class DDPGParams(TrainerParams):
critic_loss_coef: float = 0.1
random_overwrite: bool = False

def __post_init__(self) -> None:
assert self.get_q_critic_net_func is not None

def extract_ops_params(self) -> Dict[str, object]:
return {
"device": self.device,
"enable_data_parallelism": self.enable_data_parallelism,
"get_q_critic_net_func": self.get_q_critic_net_func,
"reward_discount": self.reward_discount,
"q_value_loss_cls": self.q_value_loss_cls,
"soft_update_coef": self.soft_update_coef,
"critic_loss_coef": self.critic_loss_coef,
}
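
The ``soft_update_coef`` documented above follows the Polyak-style rule target = soft_update_coef * eval + (1 - soft_update_coef) * target. A small illustrative sketch of that rule (not MARO's own implementation; the helper and parameter names are hypothetical):

```python
import torch

def soft_update(target_params, eval_params, soft_update_coef: float) -> None:
    # target <- coef * eval + (1 - coef) * target, applied parameter-wise.
    with torch.no_grad():
        for t, e in zip(target_params, eval_params):
            t.mul_(1.0 - soft_update_coef).add_(soft_update_coef * e)

# With the default soft_update_coef=1.0 this reduces to a hard copy of the eval net.
```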


class DDPGOps(AbsTrainOps):
def __init__(
@@ -159,65 +190,68 @@ def soft_update_target(self) -> None:
self._target_policy.soft_update(self._policy, self._soft_update_coef)
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch


class DDPG(SingleTrainer):
"""The Deep Deterministic Policy Gradient (DDPG) algorithm.
References:
https://arxiv.org/pdf/1509.02971.pdf
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
Args:
name (str): Unique identifier for the trainer.
policy_creator (Dict[str, Callable[[str], DiscretePolicyGradient]]): Dict of functions used to
create policies.
device (str): Identifier for the torch device. The policy will be moved to the specified device. If it is
None, the device will be set to "cpu" if cuda is unavailable and "cuda" otherwise. Defaults to None.
enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
replay_memory_capacity (int): Capacity of the replay memory. Defaults to 10000.
random_overwrite (bool): This specifies overwrite behavior when the replay memory capacity is reached. If True,
overwrite positions will be selected randomly. Otherwise, overwrites will occur sequentially with
wrap-around. Defaults to False.
num_epochs (int): Number of training epochs per call to ``learn``. Defaults to 1.
update_target_every (int): Number of training rounds between policy target model updates.
reward_discount (float): Reward decay as defined in standard RL terminology.
q_value_loss_cls: A string indicating a loss class provided by torch.nn or a custom loss class for
the Q-value loss. If it is a string, it must be a key in ``TORCH_LOSS``. Defaults to "mse".
soft_update_coef (float): Soft update coefficient, e.g., target_model = (soft_update_coef) * eval_model +
(1-soft_update_coef) * target_model. Defaults to 1.0.
critic_loss_coef (float): Coefficient for critic loss in total loss. Defaults to 0.1.
"""

def __init__(self, name: str, params: DDPGParams) -> None:
super(DDPG, self).__init__(name, params)
self._params = params
self._policy_version = self._target_policy_version = 0
self._q_net_version = self._target_q_net_version = 0
self._ops_name = f"{self._name}.ops"

self._replay_memory: Optional[RandomReplayMemory] = None

def build(self) -> None:
self._ops_params = {
"get_policy_func": self._get_policy_func,
**self._params.extract_ops_params(),
}

self._ops = self.get_ops(self._ops_name)

self._replay_memory = RandomReplayMemory(
capacity=self._params.replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite
)

def record(self, exp_element: ExpElement) -> None:
for agent_name in exp_element.agent_names:
transition_batch = TransitionBatch(
states=np.expand_dims(exp_element.agent_state_dict[agent_name], axis=0),
actions=np.expand_dims(exp_element.action_dict[agent_name], axis=0),
rewards=np.array([exp_element.reward_dict[agent_name]]),
terminals=np.array([exp_element.terminal_dict[agent_name]]),
next_states=np.expand_dims(
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
axis=0,
),
)
self._replay_memory.put(transition_batch)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DDPGOps(**self._ops_params)

def _get_batch(self, batch_size: int = None) -> TransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)

async def train_step(self) -> None:
for _ in range(self._params.num_epochs):
await self._ops.set_batch(self._get_batch())
await self._ops.update()
self._policy_version += 1
if self._policy_version - self._target_policy_version == self._params.update_target_every:
await self._ops.soft_update_target()
self._target_policy_version = self._policy_version

self._policy_version += 1
if self._policy_version - self._target_policy_version == self._params.update_target_every:
await self._ops.soft_update_target()
self._target_policy_version = self._policy_version
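
Moving the version bookkeeping outside the epoch loop means the target networks are refreshed once every ``update_target_every`` calls to ``train_step``, independent of ``num_epochs``. A toy cadence check with made-up numbers:

```python
# With update_target_every=5, the soft update fires on every fifth train_step call.
update_target_every = 5
policy_version = target_policy_version = 0
for step in range(1, 13):
    # ... num_epochs gradient updates would run here ...
    policy_version += 1
    if policy_version - target_policy_version == update_target_every:
        print(f"soft update after train_step #{step}")  # fires at steps 5 and 10
        target_policy_version = policy_version
```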
4 changes: 4 additions & 0 deletions maro/rl_v3/training/algorithms/dqn.py
@@ -133,6 +133,10 @@ def update(self) -> None:
def soft_update_target(self) -> None:
self._target_policy.soft_update(self._policy, self._soft_update_coef)

def set_batch(self, batch: TransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch


class DQN(SingleTrainer):
"""The Deep-Q-Networks algorithm.
4 changes: 4 additions & 0 deletions maro/rl_v3/training/algorithms/maddpg.py
@@ -258,6 +258,10 @@ def soft_update_target(self) -> None:
if not self._shared_critic:
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)

def set_batch(self, batch: MultiTransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch


class DiscreteMADDPG(MultiTrainer):
def __init__(self, name: str, params: DiscreteMADDPGParams) -> None:
4 changes: 2 additions & 2 deletions maro/rl_v3/training/train_ops.py
@@ -138,9 +138,9 @@ def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
"""Set ops's state."""
raise NotImplementedError

@abstractmethod
def set_batch(self, batch: AbsTransitionBatch) -> None:
assert self._is_valid_transition_batch(batch)
self._batch = batch
raise NotImplementedError

def get_policy_state(self) -> object:
return self._policy.name, self._policy.get_state()
6 changes: 6 additions & 0 deletions maro/rl_v3/utils/transition_batch.py
@@ -3,6 +3,7 @@

import numpy as np

from . import discount_cumsum
from .objects import SHAPE_CHECK_FLAG


@@ -13,6 +14,8 @@ class TransitionBatch:
rewards: np.ndarray # 1D
next_states: np.ndarray # 2D
terminals: np.ndarray # 1D
returns: np.ndarray = None # 1D
advantages: np.ndarray = None # 1D

def __post_init__(self) -> None:
if SHAPE_CHECK_FLAG:
@@ -22,6 +25,9 @@ def __post_init__(self) -> None:
assert self.next_states.shape == self.states.shape
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]

def calc_returns(self, discount_factor: float) -> None:
self.returns = discount_cumsum(self.rewards, discount_factor)
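
``calc_returns`` fills ``self.returns`` with the full discounted return from each timestep onward. A small worked example with made-up rewards, assuming the same ``discount_cumsum`` semantics as in the sketch earlier in this diff:

```python
import numpy as np

rewards = np.array([1.0, 2.0, 4.0])
discount = 0.5
# returns[2] = 4
# returns[1] = 2 + 0.5 * 4 = 4
# returns[0] = 1 + 0.5 * 4 = 3
returns = np.zeros_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
    running = rewards[t] + discount * running
    returns[t] = running
assert np.allclose(returns, [3.0, 4.0, 4.0])
```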


@dataclass
class MultiTransitionBatch: