GRPO - Feature Addition #272

Open · wants to merge 6 commits into master
11 changes: 5 additions & 6 deletions docs/misc/changelog.rst
@@ -3,19 +3,16 @@
Changelog
==========

Release 2.5.0 (2025-01-27)
Release 2.6.0 (Unknown)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to PyTorch 2.3.0
- Dropped Python 3.8 support
- Upgraded to Stable-Baselines3 >= 2.5.0

New Features:
^^^^^^^^^^^^^
- Added Python 3.12 support
- Added Numpy v2.0 support
- Added GRPO policy

Bug Fixes:
^^^^^^^^^^
@@ -27,6 +24,8 @@ Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^

Release 2.4.0 (2024-11-18)
163 changes: 163 additions & 0 deletions docs/modules/grpo.rst
@@ -0,0 +1,163 @@
.. _grpo:

.. automodule:: sb3_contrib.grpo

Generalized Policy Reward Optimization (GRPO)
=============================================

GRPO extends Proximal Policy Optimization (PPO) by introducing **generalized reward scaling** techniques.
Unlike standard PPO, which works from a single reward signal per time step, GRPO **samples multiple candidate rewards per time step**
and optimizes policy updates based on the resulting, more informative reward distribution.

This can improve **training stability** and enables **adaptive reward shaping** in complex environments.
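
To make the idea concrete, here is a minimal, self-contained sketch of the multi-sample reward scaling step (illustrative only, not the library implementation; the helper name ``scale_rewards`` is made up):

.. code-block:: python

import numpy as np

def scale_rewards(candidate_rewards: np.ndarray) -> np.ndarray:
    """Collapse several sampled rewards per step into one scaled reward.

    ``candidate_rewards`` has shape (n_steps, samples_per_time_step).
    """
    # Point estimate: the per-step mean over the sampled candidates.
    mean = candidate_rewards.mean(axis=1)
    # Scale by the spread of the whole batch so update magnitudes stay bounded.
    return mean / (candidate_rewards.std() + 1e-8)

rng = np.random.default_rng(0)
candidates = rng.normal(loc=1.0, scale=0.5, size=(128, 5))  # 5 samples per time step
print(scale_rewards(candidates).shape)  # -> (128,)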

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
CnnPolicy
MultiInputPolicy

Notes
-----

- Paper: *(Placeholder for a paper if applicable)*
- Blog post: *(Placeholder for related research or insights)*
- GRPO enables multi-sample updates and adaptive reward scaling for enhanced learning stability.

Can I use?
----------

- Recurrent policies: ❌
- Multi-processing: ✔️
- Gym spaces (a ``Dict``-observation sketch follows the table):

============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
Dict ✔️ ✔️
============= ====== ===========
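
Since ``Dict`` observation spaces are supported, ``MultiInputPolicy`` can be used directly. A minimal sketch with Stable-Baselines3's built-in ``SimpleMultiObsEnv`` (assuming GRPO accepts the same constructor arguments as PPO):

.. code-block:: python

from sb3_contrib import GRPO
from stable_baselines3.common.envs import SimpleMultiObsEnv

# SimpleMultiObsEnv exposes a Dict observation space (image + vector parts).
env = SimpleMultiObsEnv(random_start=False)
model = GRPO("MultiInputPolicy", env, samples_per_time_step=5, verbose=1)
model.learn(total_timesteps=5_000)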

.. warning::
If using GRPO with **multi-processing environments** (``SubprocVecEnv``), ensure that the sampling method remains consistent
across parallel workers to avoid reward-scaling inconsistencies.

.. warning::
When using **custom reward scaling functions**, validate that they do not introduce distribution shifts
that could destabilize training.
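
A quick, hypothetical sanity check (not part of the GRPO API) can catch such issues before a function is passed as ``reward_scaling_fn``: verify that the scaling preserves the ordering of rewards and does not strongly bias their mean.

.. code-block:: python

import numpy as np

def check_scaling_fn(scaling_fn, n_samples: int = 10_000) -> None:
    """Rough sanity check for a custom reward scaling function."""
    rewards = np.random.default_rng(0).normal(size=n_samples)
    scaled = scaling_fn(rewards)
    # A monotonic scaling keeps the ranking of rewards intact.
    assert np.array_equal(np.argsort(rewards), np.argsort(scaled)), "reward ordering changed"
    # A strongly shifted mean hints at a distribution shift.
    assert abs(float(scaled.mean())) < 1.0, "scaled rewards are heavily biased"

check_scaling_fn(lambda r: np.clip(r / (np.abs(r).max() + 1e-8), -1.0, 1.0))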


Example
-------

Train a GRPO agent on ``CartPole-v1``. This example demonstrates how **reward scaling functions** can be customized.

.. code-block:: python

import gymnasium as gym
import numpy as np

from sb3_contrib import GRPO
from stable_baselines3.common.vec_env import DummyVecEnv

def custom_reward_scaling(rewards: np.ndarray) -> np.ndarray:
    """Example: Normalize rewards between -1 and 1."""
    return np.clip(rewards / (np.abs(rewards).max() + 1e-8), -1, 1)

env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = GRPO("MlpPolicy", env, samples_per_time_step=5, reward_scaling_fn=custom_reward_scaling, verbose=1)

model.learn(total_timesteps=10_000)
model.save("grpo_cartpole")

obs = env.reset()
for _ in range(1_000):
    action, _states = model.predict(obs, deterministic=True)
    # VecEnv API: step() returns (observations, rewards, dones, infos)
    # and resets finished sub-environments automatically.
    obs, rewards, dones, infos = env.step(action)
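
To get a quick quantitative read on the trained agent, the standard Stable-Baselines3 evaluation helper can be used (a short sketch; it assumes GRPO exposes the usual ``predict()`` interface inherited from PPO):

.. code-block:: python

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")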


Results
-------

Results for GRPO applied to the ``CartPole-v1`` environment show enhanced **stability and convergence** compared to standard PPO.

**Training Performance (10k steps)**

- GRPO achieves a **higher average episode reward** with fewer fluctuations.
- Multi-sample reward updates lead to **smoother policy improvements**.

TensorBoard logs can be visualized using:

.. code-block:: bash

tensorboard --logdir ./logs/grpo_cartpole

How to replicate the results?
-----------------------------

To replicate the performance of GRPO, follow these steps (a Python-only training sketch follows the list):

1. **Clone the repository**

.. code-block:: bash

git clone https://github.com/Stable-Baselines-Team/stable-baselines3-contrib.git
cd stable-baselines3-contrib

2. **Install dependencies**

.. code-block:: bash

pip install -e .

3. **Train the GRPO agent**

.. code-block:: bash

python scripts/train_grpo.py --env CartPole-v1 --samples_per_time_step 5

4. **View results with TensorBoard**

.. code-block:: bash

tensorboard --logdir ./logs/grpo_cartpole
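
If the helper script is not present in your checkout, an equivalent run can be started directly from Python (a sketch; ``tensorboard_log`` is the standard Stable-Baselines3 logging argument and is assumed to be supported by GRPO):

.. code-block:: python

import gymnasium as gym

from sb3_contrib import GRPO
from stable_baselines3.common.vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = GRPO(
    "MlpPolicy",
    env,
    samples_per_time_step=5,
    tensorboard_log="./logs/grpo_cartpole",
    verbose=1,
)
model.learn(total_timesteps=10_000)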


Parameters
----------

.. autoclass:: GRPO
:members:
:inherited-members:


GRPO Policies
-------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.common.policies.ActorCriticPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: sb3_contrib.common.policies.ActorCriticCnnPolicy
:members:
:noindex:

.. autoclass:: MultiInputPolicy
:members:

.. autoclass:: sb3_contrib.common.policies.MultiInputActorCriticPolicy
:members:
:noindex:
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -7,6 +7,7 @@
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
from sb3_contrib.grpo import GRPO

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -21,4 +22,5 @@
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"GRPO"
]
4 changes: 4 additions & 0 deletions sb3_contrib/grpo/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.grpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.grpo.grpo import GRPO

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "GRPO"]