Skip to content

Commit

Permalink
Merge branch 'master' into feature/remove_abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
darthegg authored Apr 10, 2019
2 parents 17ef9a7 + ca729e0 commit d79c356
Show file tree
Hide file tree
Showing 15 changed files with 301 additions and 71 deletions.
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ We are warmly welcoming external contributors! :)
6. [Behaviour Cloning (BC with DDPG, SAC)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/bc)
7. [Prioritized Experience Replay (PER with DDPG)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/per)
8. [From Demonstrations (DDPGfD, SACfD, DQfD)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/fd)
9. [Rainbow DQN (without NoisyNet)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/dqn)
10. [Rainbow IQN (without DuelingNet & NoisyNet)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/dqn)
9. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/dqn)
10. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137)

## Getting started
We have tested each algorithm on some of the following environments.
Expand Down Expand Up @@ -109,6 +109,10 @@ python <run-file> -h
- `--load-from <save-file-path>`
- Load the saved models and optimizers at the beginning.

### Class Diagram
Class diagram drawn on [e447f3e](https://github.com/medipixel/rl_algorithms/commit/e447f3e743f6f85505f2275b646e46f0adcf8f89). This will be not frequently updated.
![rl_algorithms_cls](https://user-images.githubusercontent.com/14961526/55703648-26022a80-5a15-11e9-8099-9bbfdffcb96d.png)

### W&B for logging
We use [W&B](https://www.wandb.com/) for logging of network parameters and others. For more details, read [W&B tutorial](https://docs.wandb.com/docs/started.html).

Expand All @@ -128,5 +132,6 @@ We use [W&B](https://www.wandb.com/) for logging of network parameters and other
12. [Z. Wang et al., "Dueling Network Architectures for Deep Reinforcement Learning." arXiv preprint arXiv:1511.06581, 2015.](https://arxiv.org/pdf/1511.06581.pdf)
13. [T. Hester et al., "Deep Q-learning from Demonstrations." arXiv preprint arXiv:1704.03732, 2017.](https://arxiv.org/pdf/1704.03732.pdf)
14. [M. G. Bellemare et al., "A Distributional Perspective on Reinforcement Learning." arXiv preprint arXiv:1707.06887, 2017.](https://arxiv.org/pdf/1707.06887.pdf)
15. [M. Hessel et al., "Rainbow: Combining Improvements in Deep Reinforcement Learning." arXiv preprint arXiv:1710.02298, 2017.](https://arxiv.org/pdf/1710.02298.pdf)
16. [W. Dabney et al., "Implicit Quantile Networks for Distributional Reinforcement Learning." arXiv preprint arXiv:1806.06923, 2018.](https://arxiv.org/pdf/1806.06923.pdf)
15. [M. Fortunato et al., "Noisy Networks for Exploration." arXiv preprint arXiv:1706.10295, 2017.](https://arxiv.org/pdf/1706.10295.pdf)
16. [M. Hessel et al., "Rainbow: Combining Improvements in Deep Reinforcement Learning." arXiv preprint arXiv:1710.02298, 2017.](https://arxiv.org/pdf/1710.02298.pdf)
17. [W. Dabney et al., "Implicit Quantile Networks for Distributional Reinforcement Learning." arXiv preprint arXiv:1806.06923, 2018.](https://arxiv.org/pdf/1806.06923.pdf)
2 changes: 0 additions & 2 deletions algorithms/common/abstract/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ class RewardFn(ABC):
"""Abstract class for computing reward.
New compute_reward class should redefine __call__()
Attributes:
"""

@abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion algorithms/common/buffer/segment_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def sum(self, start: int = 0, end: int = 0) -> float:

def retrieve(self, upperbound: float) -> int:
"""Find the highest index `i` about upper bound in the tree"""
assert 0 <= upperbound <= self.sum() + 1e-5
# TODO: Check assert case and fix bug
assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound)

idx = 1

Expand Down
2 changes: 1 addition & 1 deletion algorithms/common/env/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import cv2
import gym
from gym import spaces
import gym.spaces as spaces
import numpy as np

os.environ.setdefault("PATH", "")
Expand Down
35 changes: 21 additions & 14 deletions algorithms/common/networks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def concat(
return in_concat


def init_layer_uniform(layer: nn.Linear, init_w: float = 3e-3) -> nn.Linear:
"""Init uniform parameters on the single layer"""
layer.weight.data.uniform_(-init_w, init_w)
layer.bias.data.uniform_(-init_w, init_w)

return layer


class MLP(nn.Module):
"""Baseline of Multilayer perceptron.
Expand All @@ -53,9 +61,10 @@ def __init__(
hidden_sizes: list,
hidden_activation: Callable = F.relu,
output_activation: Callable = identity,
linear_layer: nn.Module = nn.Linear,
use_output_layer: bool = True,
n_category: int = -1,
init_w: float = 3e-3,
init_fn: Callable = init_layer_uniform,
):
"""Initialization.
Expand All @@ -65,9 +74,10 @@ def __init__(
hidden_sizes (list): number of hidden layers
hidden_activation (function): activation function of hidden layers
output_activation (function): activation function of output layer
linear_layer (nn.Module): linear layer of mlp
use_output_layer (bool): whether or not to use the last layer
n_category (int): category number (-1 if the action is continuous)
init_w (float): weight initialization bound for the last layer
init_fn (Callable): weight initialization function bound for the last layer
"""
super(MLP, self).__init__()
Expand All @@ -77,23 +87,23 @@ def __init__(
self.output_size = output_size
self.hidden_activation = hidden_activation
self.output_activation = output_activation
self.linear_layer = linear_layer
self.use_output_layer = use_output_layer
self.n_category = n_category

# set hidden layers
self.hidden_layers: list = []
in_size = self.input_size
for i, next_size in enumerate(hidden_sizes):
fc = nn.Linear(in_size, next_size)
fc = self.linear_layer(in_size, next_size)
in_size = next_size
self.__setattr__("hidden_fc{}".format(i), fc)
self.hidden_layers.append(fc)

# set output layers
if self.use_output_layer:
self.output_layer = nn.Linear(in_size, output_size)
self.output_layer.weight.data.uniform_(-init_w, init_w)
self.output_layer.bias.data.uniform_(-init_w, init_w)
self.output_layer = self.linear_layer(in_size, output_size)
self.output_layer = init_fn(self.output_layer)
else:
self.output_layer = identity
self.output_activation = identity
Expand Down Expand Up @@ -137,7 +147,7 @@ def __init__(
mu_activation: Callable = torch.tanh,
log_std_min: float = -20,
log_std_max: float = 2,
init_w: float = 3e-3,
init_fn: Callable = init_layer_uniform,
):
"""Initialization."""
super(GaussianDist, self).__init__(
Expand All @@ -155,13 +165,11 @@ def __init__(

# set log_std layer
self.log_std_layer = nn.Linear(in_size, output_size)
self.log_std_layer.weight.data.uniform_(-init_w, init_w)
self.log_std_layer.bias.data.uniform_(-init_w, init_w)
self.log_std_layer = init_fn(self.log_std_layer)

# set mean layer
self.mu_layer = nn.Linear(in_size, output_size)
self.mu_layer.weight.data.uniform_(-init_w, init_w)
self.mu_layer.bias.data.uniform_(-init_w, init_w)
self.mu_layer = init_fn(self.mu_layer)

def get_dist_params(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Return gausian distribution parameters."""
Expand Down Expand Up @@ -229,7 +237,7 @@ def __init__(
output_size: int,
hidden_sizes: list,
hidden_activation: Callable = F.relu,
init_w: float = 3e-3,
init_fn: Callable = init_layer_uniform,
):
"""Initialization."""
super(CategoricalDist, self).__init__(
Expand All @@ -244,8 +252,7 @@ def __init__(

# set log_std layer
self.last_layer = nn.Linear(in_size, output_size)
self.last_layer.weight.data.uniform_(-init_w, init_w)
self.last_layer.bias.data.uniform_(-init_w, init_w)
self.last_layer = init_fn(self.last_layer)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Forward method implementation."""
Expand Down
33 changes: 26 additions & 7 deletions algorithms/dqn/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
https://arxiv.org/pdf/1509.06461.pdf (Double DQN)
https://arxiv.org/pdf/1511.05952.pdf (PER)
https://arxiv.org/pdf/1511.06581.pdf (Dueling)
https://arxiv.org/pdf/1706.10295.pdf (NoisyNet)
https://arxiv.org/pdf/1707.06887.pdf (C51)
https://arxiv.org/pdf/1710.02298.pdf (Rainbow)
https://arxiv.org/pdf/1806.06923.pdf (IQN)
"""

import argparse
import datetime
import os
import time
from typing import Tuple

import gym
Expand Down Expand Up @@ -191,7 +193,7 @@ def _get_dqn_loss(
gamma=gamma,
)

def update_model(self) -> torch.Tensor:
def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Train the model after each episode."""
# 1 step loss
experiences_1 = self.memory.sample(self.beta)
Expand Down Expand Up @@ -239,6 +241,10 @@ def update_model(self) -> torch.Tensor:
fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
self.beta = self.beta + fraction * (1.0 - self.beta)

if self.hyper_params["USE_NOISY_NET"]:
self.dqn.reset_noise()
self.dqn_target.reset_noise()

return loss.data, q_values.mean().data

def load_params(self, path: str):
Expand All @@ -263,11 +269,11 @@ def save_params(self, n_episode: int):

Agent.save_params(self, params, n_episode)

def write_log(self, i: int, loss: np.ndarray, score: float):
def write_log(self, i: int, loss: np.ndarray, score: float, avg_time_cost: float):
"""Write log about loss and score"""
print(
"[INFO] episode %d, episode step: %d, total step: %d, total score: %f\n"
"epsilon: %f, loss: %f, avg q-value: %f at %s\n"
"epsilon: %f, loss: %f, avg q-value: %f (spent %.6f sec/step)\n"
% (
i,
self.episode_step,
Expand All @@ -276,12 +282,20 @@ def write_log(self, i: int, loss: np.ndarray, score: float):
self.epsilon,
loss[0],
loss[1],
datetime.datetime.now(),
avg_time_cost,
)
)

if self.args.log:
wandb.log({"score": score, "dqn loss": loss[0], "epsilon": self.epsilon})
wandb.log(
{
"score": score,
"epsilon": self.epsilon,
"dqn loss": loss[0],
"avg q values": loss[1],
"time per each step": avg_time_cost,
}
)

# pylint: disable=no-self-use, unnecessary-pass
def pretrain(self):
Expand Down Expand Up @@ -312,6 +326,8 @@ def train(self):
done = False
score = 0

t_begin = time.time()

while not done:
if self.args.render and self.i_episode >= self.args.render_after:
self.env.render()
Expand All @@ -334,9 +350,12 @@ def train(self):
state = next_state
score += reward

t_end = time.time()
avg_time_cost = (t_end - t_begin) / self.episode_step

if losses:
avg_loss = np.vstack(losses).mean(axis=0)
self.write_log(self.i_episode, avg_loss, score)
self.write_log(self.i_episode, avg_loss, score, avg_time_cost)

if self.i_episode % self.args.save_period == 0:
self.save_params(self.i_episode)
Expand Down
114 changes: 114 additions & 0 deletions algorithms/dqn/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""Linear module for dqn algorithms
- Author: Kh Kim
- Contact: kh.kim@medipixel.io
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class NoisyLinear(nn.Module):
"""Noisy linear module for NoisyNet.
References:
https://github.com/higgsfield/RL-Adventure/blob/master/5.noisy%20dqn.ipynb
https://github.com/Kaixhin/Rainbow/blob/master/model.py
Attributes:
in_features (int): input size of linear module
out_features (int): output size of linear module
std_init (float): initial std value
weight_mu (nn.Parameter): mean value weight parameter
weight_sigma (nn.Parameter): std value weight parameter
bias_mu (nn.Parameter): mean value bias parameter
bias_sigma (nn.Parameter): std value bias parameter
"""

def __init__(self, in_features: int, out_features: int, std_init: float = 0.5):
"""Initialization."""
super(NoisyLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.std_init = std_init

self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
self.weight_sigma = nn.Parameter(torch.Tensor(out_features, in_features))
self.register_buffer("weight_epsilon", torch.Tensor(out_features, in_features))

self.bias_mu = nn.Parameter(torch.Tensor(out_features))
self.bias_sigma = nn.Parameter(torch.Tensor(out_features))
self.register_buffer("bias_epsilon", torch.Tensor(out_features))

self.reset_parameters()
self.reset_noise()

def reset_parameters(self):
"""Reset trainable network parameters (factorized gaussian noise)."""
mu_range = 1 / math.sqrt(self.in_features)
self.weight_mu.data.uniform_(-mu_range, mu_range)
self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features))
self.bias_mu.data.uniform_(-mu_range, mu_range)
self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features))

@staticmethod
def scale_noise(size: int) -> torch.Tensor:
"""Set scale to make noise (factorized gaussian noise)."""
x = torch.FloatTensor(np.random.normal(loc=0.0, scale=1.0, size=size))

return x.sign().mul(x.abs().sqrt())

def reset_noise(self):
"""Make new noise."""
epsilon_in = self.scale_noise(self.in_features)
epsilon_out = self.scale_noise(self.out_features)

# outer product
self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
self.bias_epsilon.copy_(epsilon_out)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method implementation.
We don't use separate statements on train / eval mode.
It doesn't show remarkable difference of performance.
"""
return F.linear(
x,
self.weight_mu + self.weight_sigma * self.weight_epsilon,
self.bias_mu + self.bias_sigma * self.bias_epsilon,
)


class NoisyLinearConstructor:
"""Constructor class for changing hyper parameters of NoisyLinear.
Attributes:
std_init (float): initial std value
"""

def __init__(self, std_init: float = 0.5):
"""Initialization."""
self.std_init = std_init

def __call__(self, in_features: int, out_features: int) -> NoisyLinear:
"""Return NoisyLinear instance set hyper parameters"""
return NoisyLinear(in_features, out_features, self.std_init)


class NoisyMLPHandler:
"""Includes methods to handle noisy linear."""

def reset_noise(self):
"""Re-sample noise"""
for _, module in self.named_children():
module.reset_noise()
Loading

0 comments on commit d79c356

Please sign in to comment.