Add C51 algorithm (DLR-RM/stable-baselines3#622)
Showing 4 changed files with 658 additions and 0 deletions.
@@ -0,0 +1,4 @@
from sb3_contrib.c51.c51 import C51
from sb3_contrib.c51.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["C51", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
@@ -0,0 +1,316 @@
import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update

from sb3_contrib.c51.policies import C51Policy, CategoricalNetwork, CnnPolicy, MlpPolicy, MultiInputPolicy

SelfC51 = TypeVar("SelfC51", bound="C51")

def project(supports, weights, target_support):
    """Projects a batch of (support, weights) onto target_support.

    Based on equation (7) in (Bellemare et al., 2017): https://arxiv.org/abs/1707.06887
    In the rest of the comments we will refer to this equation simply as Eq7.

    Args:
        supports: Batch of supports.
        weights: Batch of weights.
        target_support: Target support.

    Returns:
        Batch of weights after projection.
    """
    v_min, v_max = target_support[0], target_support[-1]
    # `N` in Eq7.
    n_atoms = target_support.shape[0]
    # delta_z = `\Delta z` in Eq7.
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7.
    clipped_support = th.clip(supports, v_min, v_max)
    # numerator = `|clipped_support - z_i|` in Eq7.
    numerator = th.abs(clipped_support[:, None] - target_support[:, None])
    quotient = 1 - (numerator / delta_z)
    # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7.
    clipped_quotient = th.clip(quotient, 0, 1)
    # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))` in Eq7.
    inner_prod = clipped_quotient * weights[:, None]
    return th.sum(inner_prod, dim=-1)
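# Worked example: with target_support = [-1, 0, 1] (so delta_z = 1), projecting
# supports = [[0.5, 1.5, 2.5]] with weights = [[0.2, 0.3, 0.5]] first clips the atoms
# to [0.5, 1.0, 1.0]; the atom at 0.5 splits its 0.2 mass evenly between z = 0 and
# z = 1, and the two clipped atoms sit exactly on z = 1, so the projected weights
# are [[0.0, 0.1, 0.9]] (still summing to 1).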


class C51(OffPolicyAlgorithm):
    """
    Categorical Deep Q-Network (C51)
    Paper: https://arxiv.org/abs/1707.06887
    Default hyperparameters are taken from the paper and are tuned for Atari games.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1), default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout
        (see ``train_freq`` and ``n_episodes_rollout``)
        Set to ``-1`` to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    categorical_net: CategoricalNetwork
    categorical_net_target: CategoricalNetwork
    policy: C51Policy

    def __init__(
        self,
        policy: Union[str, Type[C51Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 2.5e-4,
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.01,
        max_grad_norm: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        # For updating the target network with multiple envs:
        self._n_calls = 0
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0

        if "optimizer_class" not in self.policy_kwargs:
            self.policy_kwargs["optimizer_class"] = th.optim.Adam
            # Proposed in the C51 paper where `batch_size = 32`
            self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
        self.batch_norm_stats = get_parameters_by_name(self.categorical_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.categorical_net_target, ["running_"])
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
        )
        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        if self.n_envs > 1:
            if self.n_envs > self.target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the target network "
                    f"update interval ({self.n_envs} > {self.target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps."
                )

            self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

    def _create_aliases(self) -> None:
        self.categorical_net = self.policy.categorical_net
        self.categorical_net_target = self.policy.categorical_net_target
        self.support = self.categorical_net.support

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1
        if self._n_calls % self.target_update_interval == 0:
            polyak_update(self.categorical_net.parameters(), self.categorical_net_target.parameters(), self.tau)
            # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                # Compute the next categorical probabilities using the target network
                # Shape: (batch_size, n_actions, n_atoms)
                next_probabilities = th.softmax(self.categorical_net_target(replay_data.next_observations), dim=-1)
                # Compute the greedy actions which maximize the next Q-values,
                # i.e. the expectation of the support under each action's distribution
                next_actions = (next_probabilities * self.support).sum(dim=-1).argmax(dim=-1)
                # Follow greedy policy: use the one with the highest Q-values
                # Shape: (batch_size, n_atoms)
                next_probabilities = next_probabilities[th.arange(batch_size), next_actions]
                # 1-step TD target: shift the fixed support by the (discounted) reward
                target_support = replay_data.rewards + (1 - replay_data.dones) * self.gamma * self.support
                # Project the shifted support back onto the fixed support (Eq7)
                targets = project(target_support, next_probabilities, self.support)

            # Get current estimated categorical logits for the actions from the replay buffer
            logits = self.categorical_net(replay_data.observations)
            logits = logits[th.arange(batch_size), replay_data.actions.squeeze()]

            # Compute cross-entropy loss between the predicted distribution and the
            # projected target distribution (soft targets require PyTorch >= 1.10)
            loss = th.nn.functional.cross_entropy(logits, targets)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def learn(
        self: SelfC51,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "C51",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfC51:
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        # Exclude the aliases created in `_create_aliases()` (they are saved as part of the policy)
        return super()._excluded_save_params() + ["categorical_net", "categorical_net_target"]  # noqa: RUF005

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []
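The companion module sb3_contrib/c51/policies.py (imported above for C51Policy and CategoricalNetwork) is not shown in this view, but the training loop constrains its interface: categorical_net(obs) must return logits of shape (batch_size, n_actions, n_atoms) and expose a fixed support tensor of atom locations. A minimal sketch of such a network, assuming the paper's 51 atoms on [-10, 10]; this is illustrative only, not the actual CategoricalNetwork from the commit:

import torch as th
from torch import nn


class CategoricalQNet(nn.Module):
    """Illustrative categorical Q-network: maps observations to per-action atom logits."""

    def __init__(self, obs_dim: int, n_actions: int, n_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0):
        super().__init__()
        self.n_actions = n_actions
        self.n_atoms = n_atoms
        # Fixed support z_i shared by all actions (registered as a buffer so it follows .to(device))
        self.register_buffer("support", th.linspace(v_min, v_max, n_atoms))
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions * n_atoms))

    def forward(self, obs: th.Tensor) -> th.Tensor:
        # Logits of shape (batch_size, n_actions, n_atoms); train() applies softmax over the last dim
        return self.net(obs).view(-1, self.n_actions, self.n_atoms)

The Q-value of an action is then (softmax(logits, dim=-1) * support).sum(dim=-1), which is how train() above selects the greedy next actions before projecting the target distribution.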