Add GAIL algorithm #315

Merged · 40 commits · Aug 24, 2021

Commits
f381e7b
[IBR-2091] Add gail algorithm
isk03276 Jun 21, 2021
9f03a1d
[IBR-2068] Modify standard deviation of gaussian action in ppo
isk03276 Jun 21, 2021
7276000
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jun 23, 2021
bab3893
[IBR-2091] Improve gail algorithm
isk03276 Jun 23, 2021
2a8d24f
[IBR-2068] Add ppo algorithm for discrete action
isk03276 Jun 24, 2021
3ede275
[IBR-2068] Add shared backbone for actor critic
isk03276 Jun 25, 2021
ac1d8dc
[IBR-2068] Fix gpu oom bug
isk03276 Jun 25, 2021
b7373c5
[IBR-2068] Tuning hyper-parameters for ppo
isk03276 Jun 28, 2021
1c254a3
[IBR-2068] Modify multi env
isk03276 Jun 28, 2021
fcb648d
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jun 28, 2021
9747ee7
[IBR-2091] Modify input size of discriminator network
isk03276 Jun 28, 2021
3ebaf46
[IBR-2068] Modify learner for shared actor critic
isk03276 Jun 28, 2021
aeaf96a
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jun 28, 2021
e2cfd2a
[IBR-2091] Add forward_backbone and forward_head function
isk03276 Jul 1, 2021
557f0a2
[IBR-2091] Change threshold for determining discriminator accuracy
isk03276 Jul 1, 2021
8856de0
[IBR-2068] Rollback ppo config
isk03276 Jul 1, 2021
d05cc0a
[IBR-2068] Merge with master branch
isk03276 Jul 5, 2021
50f76aa
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jul 5, 2021
8cc7d05
[IBR-2068] Add ppo with discrete action
isk03276 Jul 7, 2021
f64b5f2
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jul 7, 2021
1bf052d
[IBR-2068]Remove retain_graph option
isk03276 Jul 8, 2021
1eee287
Merge branch 'hotfix/improve_ppo_algorithm' into feature/add_gail_alg…
isk03276 Jul 8, 2021
850542a
[IBR-2097] Remove retain_graph option
isk03276 Jul 8, 2021
7ded89e
[IBR-2091] Add discriminator class
isk03276 Jul 12, 2021
efaf298
Merge branch 'master' of https://github.com/medipixel/rl_algorithms i…
isk03276 Jul 13, 2021
61efa98
[IBR-2091] Modify action embedder config
isk03276 Jul 13, 2021
356cba4
[IBR-2091] Modify/Add comments
isk03276 Jul 13, 2021
294837d
[IBR-2091] Modify pylint
isk03276 Jul 13, 2021
8eaeb2e
[IBR-2091] Convet action type to numpy array in select_action function
isk03276 Jul 20, 2021
eadd1ff
[IBR-2069] Modify hidden activation function
isk03276 Jul 28, 2021
7d5ac5a
Merge branch 'feature/modify_hidden_activation' into feature/add_gail…
isk03276 Jul 28, 2021
7913708
Merge branch 'master' into feature/add_gail_algorithm
jiseongHAN Aug 19, 2021
5daf253
[IBR-2097] Modify readme
isk03276 Aug 23, 2021
922222b
Merge branch 'feature/add_gail_algorithm' of https://github.com/medip…
isk03276 Aug 23, 2021
7651f83
Merge branch 'master' into feature/add_gail_algorithm
isk03276 Aug 23, 2021
28365c3
[IBR-2097] Modify readme
isk03276 Aug 23, 2021
d91bbab
[IBR-2069] Modify readme file
isk03276 Aug 23, 2021
beaa4ad
[IBR-2097] Modfiy readme
isk03276 Aug 23, 2021
0832e7b
Update version 1.1.0 to 1.2.0
jiseongHAN Aug 23, 2021
eba07fd
update ray to 1.3.0
jiseongHAN Aug 23, 2021
10 changes: 10 additions & 0 deletions README.md
@@ -72,6 +72,7 @@ This project follows the [all-contributors](https://github.com/all-contributors/
10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent)
11. [Distributed Prioritized Experience Replay (Ape-X)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/common/apex)
12. [Policy Distillation](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/distillation)
13. [Generative Adversarial Imitation Learning (GAIL)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/gail)

## Performance

@@ -139,6 +140,14 @@ See <a href="https://app.wandb.ai/medipixel_rl/LunarLanderContinuous-v2/reports/
</p>
</details>

<details><summary><b>LunarLanderContinuous-v2: PPO, SAC, GAIL</b></summary>
<p><br>
See <a href="https://wandb.ai/chaehyeuk-lee/LunarLanderContinuous-v2?workspace=user-chaehyeuk-lee">W&B log</a> for more details. (The performance is measured on the commit <a href="https://github.com/medipixel/rl_algorithms/commit/922222b2e249f1f14bdf1a28c9f0f00752e49907">9e897ad</a>)

![lunarlandercontinuous-v2_gail](https://user-images.githubusercontent.com/23740495/130401442-8b668975-8760-4a79-b757-1c1e9a9c4e47.png)
</p>
</details>

#### Reacher-v2

We reproduced the performance of **DDPG**, **TD3**, and **SAC** on Reacher-v2 (Mujoco). They reach scores of around -3.5 to -4.5.
@@ -313,3 +322,4 @@ To cite this repository in publications:
19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations https://openreview.net/forum?id=r1lyTjAqYX, 2019.](https://openreview.net/forum?id=r1lyTjAqYX)
20. [Horgan et al., "Distributed Prioritized Experience Replay." in International Conference on Learning Representations, 2018](https://arxiv.org/pdf/1803.00933.pdf)
21. [Simonyan et al., "Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps", 2013](https://arxiv.org/pdf/1312.6034.pdf)
22. [Ho et al., "Generative adversarial imitation learning", 2016](https://arxiv.org/abs/1606.03476)
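
For quick orientation on the newly cited paper (entry 22): GAIL trains a policy π against a discriminator D over state-action pairs by approximately solving the saddle-point problem below, where π_E is the expert policy, H(π) is the causal entropy, and λ its weight. This is a sketch of the objective as stated by Ho et al., 2016, not something defined in this PR:

$$\min_{\pi}\ \max_{D}\ \mathbb{E}_{\pi}\big[\log D(s,a)\big] + \mathbb{E}_{\pi_E}\big[\log\big(1 - D(s,a)\big)\big] - \lambda H(\pi)$$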
59 changes: 59 additions & 0 deletions configs/lunarlander_continuous_v2/gail_ppo.yaml
@@ -0,0 +1,59 @@
type: "GAILPPOAgent"
hyper_params:
gamma: 0.99
tau: 0.95
batch_size: 128
max_epsilon: 0.2
min_epsilon: 0.2
epsilon_decay_period: 1500
w_value: 1.0
w_entropy: 0.001
gradient_clip_ac: 0.5
gradient_clip_cr: 1.0
epoch: 10
rollout_len: 1024
n_workers: 4
use_clipped_value_loss: False
standardize_advantage: True
gail_reward_weight: 1.0
demo_path: "data/lunarlander_continuous_demo.pkl"

learner_cfg:
type: "GAILPPOLearner"
backbone:
actor:
critic:
discriminator:
shared_actor_critic:
head:
actor:
type: "GaussianDist"
configs:
hidden_sizes: [256, 256]
output_activation: "identity"
fixed_logstd: True
critic:
type: "MLP"
configs:
hidden_sizes: [256, 256]
output_size: 1
output_activation: "identity"
discriminator:
type: "MLP"
configs:
hidden_sizes: [256, 256]
output_size: 1
output_activation: "identity"
aciton_embedder:
type: "MLP"
configs:
hidden_sizes: []
output_size: 16
output_activation: "identity"

optim_cfg:
lr_actor: 0.0003
lr_critic: 0.001
lr_discriminator: 0.0003
weight_decay: 0.0
discriminator_acc_threshold : 0.8
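
To make the discriminator-related settings above concrete (the single-output "identity" discriminator head, gail_reward_weight, and discriminator_acc_threshold), here is a minimal GAIL-style sketch of a discriminator update and the imitation reward. It is illustrative only and not the code added in this PR; the function names, label convention (expert = 1, policy = 0), and tensor shapes are assumptions:

    import torch
    import torch.nn.functional as F

    def discriminator_step(discriminator, optimizer, policy_sa, expert_sa):
        """One BCE update: push expert (state, action) pairs toward label 1, policy pairs toward 0."""
        expert_logits = discriminator(expert_sa)  # (batch, 1) raw scores from the "identity" output head
        policy_logits = discriminator(policy_sa)
        loss = F.binary_cross_entropy_with_logits(
            expert_logits, torch.ones_like(expert_logits)
        ) + F.binary_cross_entropy_with_logits(
            policy_logits, torch.zeros_like(policy_logits)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy in the spirit of discriminator_acc_threshold: a learner can pause
        # discriminator updates once it separates expert from policy data too easily.
        with torch.no_grad():
            acc = 0.5 * (
                (torch.sigmoid(expert_logits) > 0.5).float().mean()
                + (torch.sigmoid(policy_logits) < 0.5).float().mean()
            )
        return loss.item(), acc.item()

    def gail_reward(discriminator, policy_sa, reward_weight=1.0):
        """Imitation reward fed to the PPO update, scaled by gail_reward_weight."""
        with torch.no_grad():
            d = torch.sigmoid(discriminator(policy_sa))  # probability the pair looks expert-like
        return reward_weight * -torch.log(1.0 - d + 1e-8)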
2 changes: 1 addition & 1 deletion requirements.txt
@@ -10,7 +10,7 @@ tqdm

# for distributed learning
redis==3.5.3 # for ray
ray==1.2.0
ray==1.3.0
pyzmq==20.0.0
pyarrow==3.0.0

4 changes: 4 additions & 0 deletions rl_algorithms/__init__.py
@@ -21,6 +21,8 @@
from .fd.dqn_learner import DQfDLearner
from .fd.sac_agent import SACfDAgent
from .fd.sac_learner import SACfDLearner
from .gail.agent import GAILPPOAgent
from .gail.learner import GAILPPOLearner
from .ppo.agent import PPOAgent
from .ppo.learner import PPOLearner
from .recurrent.dqn_agent import R2D1Agent
@@ -45,6 +47,7 @@
    "PPOAgent",
    "SACAgent",
    "TD3Agent",
    "GAILPPOAgent",
    "A2CLearner",
    "BCDDPGLearner",
    "BCSACLearner",
@@ -56,6 +59,7 @@
    "PPOLearner",
    "SACLearner",
    "TD3Learner",
    "GAILPPOLearner",
    "R2D1Learner",
    "LunarLanderContinuousHER",
    "ReacherHER",
60 changes: 60 additions & 0 deletions rl_algorithms/common/buffer/gail_buffer.py
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""Demo buffer for GAIL algorithm."""

import pickle
from typing import List, Tuple

import numpy as np
import torch

from rl_algorithms.common.abstract.buffer import BaseBuffer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class GAILBuffer(BaseBuffer):
    """Buffer to store expert states and actions.

    Attributes:
        obs_buf (np.ndarray): observations
        acts_buf (np.ndarray): actions
    """

    def __init__(self, dataset_path: str):
        """Initialize a Buffer.

        Args:
            dataset_path (str): path of the demo dataset
        """
        self.obs_buf: np.ndarray = None
        self.acts_buf: np.ndarray = None

        self.load_demo(dataset_path)

    def load_demo(self, dataset_path: str):
        """Load demo data."""
        with open(dataset_path, "rb") as f:
            demo = list(pickle.load(f))
        demo = np.array(demo)
        self.obs_buf = np.array(list(map(np.array, demo[:, 0])))
        self.acts_buf = np.array(list(map(np.array, demo[:, 1])))

    def add(self):
        pass

    def sample(self, batch_size, indices: List[int] = None) -> Tuple[np.ndarray, ...]:
        """Randomly sample a batch of experiences from memory."""
        assert 0 < batch_size < len(self)

        if indices is None:
            indices = np.random.choice(len(self), size=batch_size)

        states = self.obs_buf[indices]
        actions = self.acts_buf[indices]

        return torch.Tensor(states).to(device), torch.Tensor(actions).to(device)

    def __len__(self) -> int:
        """Return the current size of internal memory."""
        return len(self.obs_buf)
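
A small usage sketch for the buffer above (illustrative only, not part of this PR's diff). It assumes the demo pickle referenced by demo_path in gail_ppo.yaml already exists and stores a sequence of (state, action) pairs, which is the layout load_demo expects:

    from rl_algorithms.common.buffer.gail_buffer import GAILBuffer

    # Load the expert demonstrations referenced in configs/lunarlander_continuous_v2/gail_ppo.yaml.
    demo_buffer = GAILBuffer("data/lunarlander_continuous_demo.pkl")
    print(len(demo_buffer))  # number of expert (state, action) pairs

    # Draw a mini-batch for a discriminator update; tensors are moved to GPU if available.
    states, actions = demo_buffer.sample(batch_size=128)
    print(states.shape, actions.shape)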
1 change: 1 addition & 0 deletions rl_algorithms/gail/__init__.py
@@ -0,0 +1 @@
"""Empty."""