Skip to content

Commit

Permalink
Add C51 algorithm (DLR-RM/stable-baselines3#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiabir committed Dec 5, 2023
1 parent 9f333ff commit e744839
Show file tree
Hide file tree
Showing 4 changed files with 658 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.c51 import C51
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
Expand All @@ -14,6 +15,7 @@

__all__ = [
"ARS",
"C51",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
Expand Down
4 changes: 4 additions & 0 deletions sb3_contrib/c51/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sb3_contrib.c51.c51 import C51
from sb3_contrib.c51.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["C51", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
316 changes: 316 additions & 0 deletions sb3_contrib/c51/c51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update

from sb3_contrib.c51.policies import C51Policy, CategoricalNetwork, CnnPolicy, MlpPolicy, MultiInputPolicy

SelfC51 = TypeVar("SelfC51", bound="C51")


def project(supports, weights, target_support):
"""Projects a batch of (support, weights) onto target_support.
Based on equation (7) in (Bellemare et al., 2017): https://arxiv.org/abs/1707.06887
In the rest of the comments we will refer to this equation simply as Eq7.
Args:
supports: Batch of supports.
weights: Batch of weights.
target_support: Target support.
Returns:
Batch of weights after projection.
"""
v_min, v_max = target_support[0], target_support[-1]
# `N` in Eq7.
n_atoms = target_support.shape[0]
# delta_z = `\Delta z` in Eq7.
delta_z = (v_max - v_min) / (n_atoms - 1)
# clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7.
clipped_support = th.clip(supports, v_min, v_max)
# numerator = `|clipped_support - z_i|` in Eq7.
numerator = th.abs(clipped_support[:, None] - target_support[:, None])
quotient = 1 - (numerator / delta_z)
# clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7.
clipped_quotient = th.clip(quotient, 0, 1)
# inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))` in Eq7.
inner_prod = clipped_quotient * weights[:, None]
return th.sum(inner_prod, dim=-1)


class C51(OffPolicyAlgorithm):
"""
Categorical Deep Q-Network (C51)
Paper: https://arxiv.org/abs/1707.06887
Default hyperparameters are taken from the paper and are tuned for Atari games.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
# Linear schedule will be defined in `_setup_model()`
exploration_schedule: Schedule
categorical_net: CategoricalNetwork
categorical_net_target: CategoricalNetwork
policy: C51Policy

def __init__(
self,
policy: Union[str, Type[C51Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 2.5e-4,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: int = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.01,
max_grad_norm: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
action_noise=None, # No action noise
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(spaces.Discrete,),
support_multi_env=True,
)

self.exploration_initial_eps = exploration_initial_eps
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.target_update_interval = target_update_interval
# For updating the target network with multiple envs:
self._n_calls = 0
self.max_grad_norm = max_grad_norm
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0

if "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.Adam
# Proposed in the C51 paper where `batch_size = 32`
self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)

if _init_setup_model:
self._setup_model()

def _setup_model(self) -> None:
super()._setup_model()
self._create_aliases()
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
self.batch_norm_stats = get_parameters_by_name(self.categorical_net, ["running_"])
self.batch_norm_stats_target = get_parameters_by_name(self.categorical_net_target, ["running_"])
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
"The number of environments used is greater than the target network "
f"update interval ({self.n_envs} > {self.target_update_interval}), "
"therefore the target network will be updated after each call to env.step() "
f"which corresponds to {self.n_envs} steps."
)

self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

def _create_aliases(self) -> None:
self.categorical_net = self.policy.categorical_net
self.categorical_net_target = self.policy.categorical_net_target
self.support = self.categorical_net.support

def _on_step(self) -> None:
"""
Update the exploration rate and target network if needed.
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
polyak_update(self.categorical_net.parameters(), self.categorical_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
self.logger.record("rollout/exploration_rate", self.exploration_rate)

def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)

losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]

with th.no_grad():
# Compute the next categorical probabilities using the target network
next_probabilities = th.softmax(self.categorical_net_target(replay_data.next_observations), dim=-1)
# Compute the greedy actions which maximize the next Q values
next_actions = (next_probabilities * self.support).mean(dim=-1).argmax(dim=-1)
# Follow greedy policy: use the one with the highest Q values
next_probabilities = next_probabilities[th.arange(self.batch_size), next_actions]
# 1-step TD target
target_support = replay_data.rewards + (1 - replay_data.dones) * self.gamma * self.support
# Project
targets = project(target_support, next_probabilities, self.support)

# Get current estimated categorical logits
logits = self.categorical_net(replay_data.observations)
logits = logits[np.arange(self.batch_size), replay_data.actions.squeeze()]

# Compute cross-entropy loss
loss = th.nn.functional.cross_entropy(logits, targets)
losses.append(loss.item())

# Optimize the policy
self.policy.optimizer.zero_grad()
loss.backward()
# Clip gradient norm
if self.max_grad_norm is not None:
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

# Increase update counter
self._n_updates += gradient_steps

self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/loss", np.mean(losses))

def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state (used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[next(iter(observation.keys()))].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state

def learn(
self: SelfC51,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "C51",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfC51:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] # noqa: RUF005

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []
Loading

0 comments on commit e744839

Please sign in to comment.