Example: Simple RL example using DQN/Lightning #1232

Merged: 36 commits, merged on Mar 28, 2020. Changes shown below are from 7 commits.

Commits (36):
07a4e7d  Example: Simple RL example using DQN/Lightning (Mar 22, 2020)
05cf5ac  Applied autopep8 fixes (Mar 23, 2020)
fc9f31d  * Updated line length from 120 to 110 (Mar 25, 2020)
cafea47  Update pl_examples/domain_templates/dqn.py (djbyrne, Mar 25, 2020)
606f1f2  Update pl_examples/domain_templates/dqn.py (djbyrne, Mar 25, 2020)
45d671a  CI: split tests-examples (#990) (Borda, Mar 25, 2020)
31ef2eb  Clean up (Mar 25, 2020)
e86e6b2  updated example image (williamFalcon, Mar 26, 2020)
d2ef4fa  update types (Borda, Mar 26, 2020)
b4b8dd7  rename script (Borda, Mar 26, 2020)
3255539  Update CHANGELOG.md (djbyrne, Mar 26, 2020)
8b2c9e2  another rename (Borda, Mar 26, 2020)
2a4cd47  Disable validation when val_percent_check=0 (#1251) (Mar 27, 2020)
d394b80  calling self.forward() -> self() (#1211) (jeremyjordan, Mar 27, 2020)
6a0b171  Fix requirements-extra.txt Trains package to release version (#1229) (bmartinn, Mar 27, 2020)
6772e0c  Remove unnecessary parameters to super() in documentation and source … (TylerYep, Mar 27, 2020)
593bf50  update deprecation warning (#1258) (Borda, Mar 27, 2020)
bec43c9  update docs for progress bat values (#1253) (Borda, Mar 27, 2020)
9bb2e00  lower timeouts for inactive issues (#1250) (Borda, Mar 27, 2020)
da18534  update contrib list (#1241) (Borda, Mar 27, 2020)
3a93aaf  Fix outdated docs (#1227) (5n7-sk, Mar 27, 2020)
ac6692d  Fix typo (#1224) (5n7-sk, Mar 27, 2020)
1a9719c  drop unused Tox (#1242) (Borda, Mar 27, 2020)
61177cd  system info (#1234) (Borda, Mar 27, 2020)
12b39a7  Changed smoothing in tqdm to decrease variability of time remaining b… (pertschuk, Mar 27, 2020)
582fe4c  Example: Simple RL example using DQN/Lightning (Mar 22, 2020)
3ed2739  Applied autopep8 fixes (Mar 23, 2020)
dac522e  * Updated line length from 120 to 110 (Mar 25, 2020)
ec85171  Update pl_examples/domain_templates/dqn.py (djbyrne, Mar 25, 2020)
eb72022  Update pl_examples/domain_templates/dqn.py (djbyrne, Mar 25, 2020)
e03e015  Clean up (Mar 25, 2020)
0e9ca89  update types (Borda, Mar 26, 2020)
42838b3  rename script (Borda, Mar 26, 2020)
9cf4915  Update CHANGELOG.md (djbyrne, Mar 26, 2020)
f9be8b0  another rename (Borda, Mar 26, 2020)
a44c90c  Merge branch 'dqn_example' of https://github.com/djbyrne/pytorch-ligh… (Mar 28, 2020)

CHANGELOG.md: 1 change (1 addition, 0 deletions)

@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added DQN Lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
- Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))

pl_examples/domain_templates/dqn.py: 358 changes (358 additions, 0 deletions)

@@ -0,0 +1,358 @@
"""
This example is based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py

The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
classic CartPole environment.

To run the template, just run:
python dqn.py

After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to
see the metrics:

tensorboard --logdir default
"""

import argparse
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

import pytorch_lightning as pl


class DQN(nn.Module):
"""
Simple MLP network

Args:
obs_size: observation/state size of the environment
n_actions: number of discrete actions available in the environment
hidden_size: size of hidden layers
"""

def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, n_actions)
)

def forward(self, x):
return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
'Experience', field_names=['state', 'action', 'reward',
'done', 'new_state'])


class ReplayBuffer:
"""
Replay Buffer for storing past experiences allowing the agent to learn from them

Args:
capacity: size of the buffer
"""

def __init__(self, capacity: int) -> None:
self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
return len(self.buffer)

def append(self, experience: Experience) -> None:
"""
Add experience to the buffer

Args:
experience: tuple (state, action, reward, done, new_state)
"""
self.buffer.append(experience)

def sample(self, batch_size: int) -> Tuple:
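        """
        Draws a random mini batch of experiences from the buffer

        Args:
            batch_size: number of experiences to sample

        Returns:
            tuple of numpy arrays (states, actions, rewards, dones, next_states)
        """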
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=bool), np.array(next_states))

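# A minimal usage sketch (illustrative only; the names below are placeholders, not part of this script):
#   buffer = ReplayBuffer(capacity=1000)
#   buffer.append(Experience(state, action, reward, done, new_state))
#   states, actions, rewards, dones, next_states = buffer.sample(batch_size=16)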

class RLDataset(IterableDataset):
"""
    Iterable Dataset containing the ReplayBuffer
which will be updated with new experiences during training

Args:
buffer: replay buffer
sample_size: number of experiences to sample at a time
"""

def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
self.buffer = buffer
self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
for i in range(len(dones)):
yield states[i], actions[i], rewards[i], dones[i], new_states[i]

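# Design note: RLDataset is an IterableDataset rather than a map-style Dataset because the replay
# buffer keeps changing as the agent trains; every pass over the dataloader re-samples from the
# current buffer contents.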

class Agent:
"""
    Base Agent class handling the interaction with the environment

Args:
env: training environment
replay_buffer: replay buffer storing experiences
"""

def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
self.env = env
self.replay_buffer = replay_buffer
self.reset()
self.state = self.env.reset()

def reset(self) -> None:
""" Resents the environment and updates the state"""
self.state = self.env.reset()

def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
"""
Using the given network, decide what action to carry out
using an epsilon-greedy policy

Args:
net: DQN network
epsilon: value to determine likelihood of taking a random action
device: current device

Returns:
action
"""
if np.random.random() < epsilon:
action = self.env.action_space.sample()
else:
state = torch.tensor([self.state])

if device not in ['cpu']:
state = state.cuda(device)

q_values = net(state)
_, action = torch.max(q_values, dim=1)
action = int(action.item())

return action

@torch.no_grad()
def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
"""
Carries out a single interaction step between the agent and the environment

Args:
net: DQN network
epsilon: value to determine likelihood of taking a random action
device: current device

Returns:
reward, done
"""

action = self.get_action(net, epsilon, device)

# do step in the environment
new_state, reward, done, _ = self.env.step(action)

exp = Experience(self.state, action, reward, done, new_state)

self.replay_buffer.append(exp)

self.state = new_state
if done:
self.reset()
return reward, done


class DQNLightning(pl.LightningModule):
""" Basic DQN Model """

def __init__(self, hparams: argparse.Namespace) -> None:
super().__init__()
self.hparams = hparams

self.env = gym.make(self.hparams.env)
obs_size = self.env.observation_space.shape[0]
n_actions = self.env.action_space.n

self.net = DQN(obs_size, n_actions)
self.target_net = DQN(obs_size, n_actions)

self.buffer = ReplayBuffer(self.hparams.replay_size)
self.agent = Agent(self.env, self.buffer)
self.total_reward = 0
self.episode_reward = 0
self.populate(self.hparams.warm_start_steps)

def populate(self, steps: int = 1000) -> None:
"""
Carries out several random steps through the environment to initially fill
up the replay buffer with experiences

Args:
steps: number of random steps to populate the buffer with
"""
for i in range(steps):
self.agent.play_step(self.net, epsilon=1.0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Passes in a state x through the network and gets the q_values of each action as an output

Args:
x: environment state

Returns:
q values
"""
output = self.net(x)
return output

def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""
Calculates the mse loss using a mini batch from the replay buffer

Args:
batch: current mini batch of replay data

Returns:
loss
"""
states, actions, rewards, dones, next_states = batch

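        # Q(s, a) for the actions that were actually taken, picked out along the action dimension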
state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

with torch.no_grad():
next_state_values = self.target_net(next_states).max(1)[0]
next_state_values[dones] = 0.0
next_state_values = next_state_values.detach()

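        # Bellman target: r + gamma * max_a' Q_target(s', a'), with terminal next states masked to 0 above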
expected_state_action_values = next_state_values * self.hparams.gamma + rewards

return nn.MSELoss()(state_action_values, expected_state_action_values)

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
"""
Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received

Args:
batch: current mini batch of replay data
nb_batch: batch number

Returns:
Training loss and log metrics
"""
device = self.get_device(batch)
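        # linearly decay epsilon from eps_start towards eps_end over eps_last_frame global steps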
        epsilon = max(self.hparams.eps_end, self.hparams.eps_start -
                      (self.global_step + 1) / self.hparams.eps_last_frame)

# step through environment with agent
reward, done = self.agent.play_step(self.net, epsilon, device)
self.episode_reward += reward

# calculates training loss
loss = self.dqn_mse_loss(batch)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.episode_reward = 0

        # Sync the target network with the main network every `sync_rate` steps (a full hard update)
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

log = {'total_reward': torch.tensor(self.total_reward).to(device),
'reward': torch.tensor(reward).to(device),
'steps': torch.tensor(self.global_step).to(device)}

return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})

def configure_optimizers(self) -> List[Optimizer]:
""" Initialize Adam optimizer"""
optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
return [optimizer]

def __dataloader(self) -> DataLoader:
"""Initialize the Replay Buffer dataset used for retrieving experiences"""
dataset = RLDataset(self.buffer, self.hparams.episode_length)
dataloader = DataLoader(dataset=dataset,
batch_size=self.hparams.batch_size,
sampler=None
)
return dataloader

def train_dataloader(self) -> DataLoader:
"""Get train loader"""
return self.__dataloader()

def get_device(self, batch) -> str:
"""Retrieve device currently being used by minibatch"""
return batch[0].device.index if self.on_gpu else 'cpu'


def main(hparams) -> None:
model = DQNLightning(hparams)

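    # Note: this Trainer configuration assumes at least one GPU is available; on a CPU-only
    # machine you would typically set gpus=0 and drop distributed_backend='dp'.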
trainer = pl.Trainer(
gpus=1,
distributed_backend='dp',
early_stop_callback=False,
val_check_interval=100
)

trainer.fit(model)


if __name__ == '__main__':
torch.manual_seed(0)
np.random.seed(0)

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
parser.add_argument("--sync_rate", type=int, default=10,
help="how many frames do we update the target network")
parser.add_argument("--replay_size", type=int, default=1000,
help="capacity of the replay buffer")
parser.add_argument("--warm_start_size", type=int, default=1000,
help="how many samples do we use to fill our buffer at the start of training")
parser.add_argument("--eps_last_frame", type=int, default=1000,
help="what frame should epsilon stop decaying")
parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
parser.add_argument("--max_episode_reward", type=int, default=200,
help="max episode reward in the environment")
parser.add_argument("--warm_start_steps", type=int, default=1000,
help="max episode reward in the environment")

args = parser.parse_args()

main(args)

pl_examples/requirements.txt: 3 changes (2 additions, 1 deletion)

@@ -1 +1,2 @@
torchvision>=0.4.0
gym>=0.17.0