Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPO #272

Merged
merged 86 commits into from
Mar 28, 2023
Merged

PPO #272

Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
dae80b6
Adding initial PPO Code
kshitijkg May 3, 2022
f817ce9
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
4de8417
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
bc5817e
ppo updater fixed
sriyash421 May 4, 2022
f422b36
ppo config updated
May 4, 2022
d1b7ee2
Updated Hive to use gym spaces instead of raw tuples to represent act…
dapatil211 Apr 20, 2022
81ac508
Updated tests to affect api change of 1106ec2
dapatil211 Apr 20, 2022
e2c1133
Adding initial PPO Code
kshitijkg May 3, 2022
6eb1e4e
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
2277abd
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
d780683
ppo updater fixed
sriyash421 May 4, 2022
e0cc263
ppo config updated
May 4, 2022
ab5ef03
ppo replay added
sriyash421 May 6, 2022
301e77c
ppo replay conflict
sriyash421 May 6, 2022
6459606
ppo replay fixed
sriyash421 May 6, 2022
babfcce
ppo agent updated
sriyash421 May 6, 2022
c0f9039
ppo agent and config updated
sriyash421 May 6, 2022
a973cb3
ppo code running but buggy
May 9, 2022
0c78e2d
cartpole working
May 12, 2022
f5271bf
ppo configs
May 18, 2022
ad3ed9b
ppo net fixed
May 24, 2022
0f63802
merge dev
May 24, 2022
f598433
atari configs added
May 24, 2022
a3c3c1c
ppo_nets done
May 24, 2022
47e106f
ppo_replay done
May 24, 2022
12a4b73
ppo env wrappers added
May 24, 2022
277b9bc
ppo agent done
May 24, 2022
9417e05
configs done
May 24, 2022
d3616a7
stack size > 1 handled temporarily
May 27, 2022
9de6e3a
linting fixed
sriyash421 May 27, 2022
c151b7b
Merge branch 'dev' into ppo_spaces
sriyash421 Jun 27, 2022
7201e97
last batch drop fix
sriyash421 Jun 27, 2022
45a9da4
config changes
sriyash421 Jun 29, 2022
4ea7527
Merge branch 'ppo_spaces' of github.com:chandar-lab/RLHive into ppo_s…
sriyash421 Jun 29, 2022
a9848e1
shared network added
sriyash421 Jul 7, 2022
3adcb73
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 7, 2022
54996f3
reward wrapper added
sriyash421 Jul 13, 2022
fa9297b
linting fixed
sriyash421 Jul 13, 2022
2ac73ba
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 13, 2022
3d11136
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 29, 2022
4f5b4b8
docs fixed
sriyash421 Aug 28, 2022
cc4aab8
replay changed
sriyash421 Aug 28, 2022
2d55afa
update loop
sriyash421 Aug 28, 2022
dcf2aa7
type specification
sriyash421 Aug 28, 2022
936c3b1
env wrappers registered
sriyash421 Aug 28, 2022
04d8692
linting fixed
sriyash421 Aug 29, 2022
360dc00
Merge branch 'dev' into ppo_spaces
kshitijkg Sep 25, 2022
b1613cd
Removed one off transition, cleaned up replay buffer
kshitijkg Sep 25, 2022
bdcd11e
Fixed linter issues
kshitijkg Sep 25, 2022
5a8e2da
wrapper error fixed
sriyash421 Sep 29, 2022
a54377a
added vars to dict; fixed long lines and var names; moved wrapper reg…
sriyash421 Oct 11, 2022
9680185
config fixed
sriyash421 Oct 13, 2022
2c9295f
addded normalisation and fixed log
sriyash421 Oct 13, 2022
767f96c
norm filed added
sriyash421 Oct 14, 2022
b4f2ea1
norm bug fixed
sriyash421 Nov 3, 2022
58f5ec2
rew norm updated
sriyash421 Nov 11, 2022
306faea
fixes
sriyash421 Nov 11, 2022
35d6aeb
fixing norm bug; config
sriyash421 Nov 23, 2022
7d31faf
config fixes
sriyash421 Nov 23, 2022
b84722e
obs norm
sriyash421 Nov 24, 2022
a4c1692
hardcoded wrappers added
sriyash421 Nov 24, 2022
11ccb21
normaliser shape fixed
sriyash421 Dec 6, 2022
0991e84
rew shape fixed; norm structure updated
sriyash421 Dec 6, 2022
c7f42a1
rew norm
sriyash421 Dec 6, 2022
84d933e
configs and wrapper fixed
sriyash421 Dec 7, 2022
3f01532
merge dev
sriyash421 Dec 19, 2022
54799c2
Merge branch 'dev' into ppo_spaces
sriyash421 Dec 19, 2022
8fb9902
Fixed formatting and naming
kshitijkg Jan 30, 2023
bd5c587
Added env wrapper logic
kshitijkg Jan 30, 2023
697a78c
Merging dev
kshitijkg Jan 30, 2023
a1e77fa
Renamed PPO Replay Buffer to On Policy Replay buffer
kshitijkg Jan 30, 2023
031f462
Made PPO Stateless Agent
kshitijkg Jan 30, 2023
28733ec
Fixed linting issues
kshitijkg Jan 30, 2023
8885a89
Minor modifications
kshitijkg Feb 7, 2023
0e42146
Fixed changed
kshitijkg Feb 8, 2023
d785c85
Formatting and minor changes
kshitijkg Mar 2, 2023
4946874
Merge branch 'dev' into ppo_spaces
dapatil211 Mar 20, 2023
308f111
Refactored Advatange Computation
kshitijkg Mar 21, 2023
543fc74
Reformating with black
kshitijkg Mar 21, 2023
43c3fb1
Renaming
kshitijkg Mar 21, 2023
4d82f99
Refactored Normalization code
kshitijkg Mar 21, 2023
e7d08d5
Added saving and loading of state dict for normalizers
kshitijkg Mar 21, 2023
aba7c49
Fixed multiplayer replay buffer for PPO
kshitijkg Mar 21, 2023
000c4e4
Fixed minor bug
kshitijkg Mar 22, 2023
3d6d076
Renamed file
kshitijkg Mar 22, 2023
aabeed0
Added lr annealing
dapatil211 Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hive.agents.dqn import DQNAgent
from hive.agents.drqn import DRQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.ppo import PPOAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
from hive.agents.td3 import TD3
Expand All @@ -16,6 +17,7 @@
"DQNAgent": DQNAgent,
"DRQNAgent": DRQNAgent,
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"PPOAgent": PPOAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"TD3": TD3,
Expand Down
405 changes: 405 additions & 0 deletions hive/agents/ppo.py

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions hive/agents/qnets/normalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Tuple

import numpy as np

from hive.utils.registry import Registrable, registry

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class MeanStd:
"""Tracks the mean, variance and count of values."""

# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
"""Tracks the mean, variance and count of values."""
self.mean = np.zeros(shape, "float64")
self.var = np.ones(shape, "float64")
self.count = epsilon

def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)

def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
self.mean, self.var, self.count = self.update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)

def update_mean_var_count_from_moments(
self, mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
delta = batch_mean - mean
tot_count = count + batch_count

new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count

return new_mean, new_var, new_count


class BaseNormalizationFn(object):
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"""Implements the base normalization function."""

def __init__(self, *args, **kwds):
pass

def __call__(self, *args, **kwds):
return NotImplementedError

def update(self, *args, **kwds):
return NotImplementedError


class ObservationNormalizationFn(BaseNormalizationFn):
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"""Implements a normalization function. Transforms output by
normalising the input data by the running :obj:`mean` and
:obj:`std`, and clipping the normalised data on :obj:`clip`
"""

def __init__(
self, shape: Tuple[int, ...], epsilon: float = 1e-4, clip: np.float32 = np.inf
):
"""
Args:
epsilon (float): minimum value of variance to avoid division by 0.
shape (tuple[int]): The shape of input data.
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self.obs_rms = MeanStd(epsilon, shape)
self._shape = shape
self._epsilon = epsilon
self._clip = clip

def __call__(self, obs):
obs = np.array([obs])
obs = ((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self._epsilon))[0]
if self._clip is not None:
obs = np.clip(obs, -self._clip, self._clip)
return obs

def update(self, obs):
self.obs_rms.update(obs)


class RewardNormalizationFn(BaseNormalizationFn):
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"""Implements a normalization function. Transforms output by
normalising the input data by the running :obj:`mean` and
:obj:`std`, and clipping the normalised data on :obj:`clip`
"""
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf):
"""
Args:
gamma (float): discount factor for the agent.
epsilon (float): minimum value of variance to avoid division by 0.
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self.return_rms = MeanStd(epsilon, ())
self._epsilon = epsilon
self._clip = clip
self._gamma = gamma
self._returns = np.zeros(1)

def __call__(self, rew):
rew = np.array([rew])
rew = (rew / np.sqrt(self.return_rms.var + self._epsilon))[0]
if self._clip is not None:
rew = np.clip(rew, -self._clip, self._clip)
return rew

def update(self, rew, done):
self._returns = self._returns * self._gamma + rew
self.return_rms.update(self._returns)
self._returns *= 1 - done


class NormalizationFn(Registrable):
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"""A wrapper for callables that produce normalization functions.

These wrapped callables can be partially initialized through configuration
files or command line arguments.
"""

@classmethod
def type_name(cls):
"""
Returns:
"norm_fn"
"""
return "norm_fn"


registry.register_all(
NormalizationFn,
{
"BaseNormalization": BaseNormalizationFn,
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"RewardNormalization": RewardNormalizationFn,
"ObservationNormalization": ObservationNormalizationFn,
},
)

get_norm_fn = getattr(registry, f"get_{NormalizationFn.type_name()}")
116 changes: 116 additions & 0 deletions hive/agents/qnets/ppo_nets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Tuple, Union
import gym
from gym.spaces import Box, Discrete
import numpy as np
import torch

from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import calculate_output_dim


class CategoricalHead(torch.nn.Module):
"""A module that implements a discrete actor head. It uses the ouput from the
:obj:`actor_net`, and adds creates a :py:class:`~torch.distributions.categorical.Categorical`
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
object to compute the action distribution."""

def __init__(
self, feature_dim: Tuple[int], action_space: gym.spaces.Discrete
) -> None:
"""
Args:
feature dim: Expected output shape of the actor network.
action_shape: Expected shape of actions.
"""
super().__init__()
self.network = torch.nn.Linear(feature_dim, action_space.n)
self.distribution = torch.distributions.categorical.Categorical

def forward(self, x):
logits = self.network(x)
return self.distribution(logits=logits)


class GaussianPolicyHead(torch.nn.Module):
"""A module that implements a continuous actor head. It uses the output from the
:obj:`actor_net` and state independent learnable parameter :obj:`policy_logstd` to
create a :py:class:`~torch.distributions.normal.Normal` object to compute
the action distribution."""

def __init__(self, feature_dim: Tuple[int], action_space: gym.spaces.Box) -> None:
"""
Args:
feature dim: Expected output shape of the actor network.
action_shape: Expected shape of actions.
"""
super().__init__()
self._action_shape = action_space.shape
self.policy_mean = torch.nn.Sequential(
torch.nn.Linear(feature_dim, np.prod(self._action_shape))
)
self.policy_logstd = torch.nn.Parameter(
torch.zeros(1, np.prod(action_space.shape))
)
self.distribution = torch.distributions.normal.Normal

def forward(self, x):
_mean = self.policy_mean(x)
_std = self.policy_logstd.repeat(x.shape[0], 1).exp()
distribution = self.distribution(
torch.reshape(_mean, (x.size(0), *self._action_shape)),
torch.reshape(_std, (x.size(0), *self._action_shape)),
)
return distribution


class PPOActorCriticNetwork(torch.nn.Module):
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
"""A module that implements the PPO actor and critic computation. It puts together the
:obj:`representation_network`, :obj:`actor_net` and :obj:`critic_net`, then adds two final
:py:class:`~torch.nn.Linear` layers to compute the action and state value."""

kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
representation_network: torch.nn.Module,
actor_net: FunctionApproximator,
critic_net: FunctionApproximator,
network_output_dim: Union[int, Tuple[int]],
action_space: Union[Box, Discrete],
continuous_action: bool,
) -> None:
super().__init__()
self._network = representation_network
self._continuous_action = continuous_action
if actor_net is None:
actor_network = torch.nn.Identity()
else:
actor_network = actor_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(actor_network, network_output_dim))
actor_head = GaussianPolicyHead if self._continuous_action else CategoricalHead

self.actor = torch.nn.Sequential(
actor_network,
torch.nn.Flatten(),
actor_head(feature_dim, action_space),
)

if critic_net is None:
critic_network = torch.nn.Identity()
else:
critic_network = critic_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(critic_network, network_output_dim))
self.critic = torch.nn.Sequential(
critic_network,
torch.nn.Flatten(),
torch.nn.Linear(feature_dim, 1),
)

def forward(self, x, action=None):
hidden_state = self._network(x)
distribution = self.actor(hidden_state)
value = self.critic(hidden_state)
if action is None:
action = distribution.sample()

logprob, entropy = distribution.log_prob(action), distribution.entropy()
if self._continuous_action:
logprob, entropy = logprob.sum(dim=-1), entropy.sum(dim=-1)
return action, logprob, entropy, value
2 changes: 2 additions & 0 deletions hive/agents/qnets/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
from typing import Tuple

import numpy as np
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
import torch

from hive.utils.registry import registry
Expand Down
64 changes: 64 additions & 0 deletions hive/configs/atari/ppo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
run_name: &run_name 'atari-ppo'
train_steps: 10000000
test_frequency: 250000
test_episodes: 10
max_steps_per_episode: 27000
stack_size: &stack_size 4
save_dir: 'experiment'
saving_schedule:
name: 'PeriodicSchedule'
kwargs:
off_value: False
on_value: True
period: 1000000
environment:
name: 'AtariEnv'
kwargs:
env_name: 'Breakout'

agent:
name: 'PPOAgent'
kwargs:
representation_net:
name: 'ConvNetwork'
kwargs:
channels: [32, 64, 64]
kernel_sizes: [8, 4, 3]
strides: [4, 2, 1]
paddings: [2, 2, 1]
mlp_layers: [512]
optimizer_fn:
name: 'Adam'
kwargs:
lr: .00025
init_fn:
name: 'orthogonal'
replay_buffer:
name: 'PPOReplayBuffer'
kwargs:
stack_size: *stack_size
use_gae: True
gae_lambda: .95
discount_rate: .99
grad_clip: .5
clip_coef: .1
ent_coef: .0
clip_vloss: True
vf_coef: .5
transitions_per_update: 4096
num_epochs_per_update: 4
normalize_advantages: True
batch_size: 256
device: 'cuda'
id: 'agent'
# List of logger configs used.
loggers:
-
name: ChompLogger
-
name: WandbLogger
kwargs:
project: Hive
name: *run_name
resume: "allow"
start_method: "fork"
Loading