diff --git a/README.md b/README.md
index 309d524a2..7031d9fe8 100644
--- a/README.md
+++ b/README.md
@@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks:
 | **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
 | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
-| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
-| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
+| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
 | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
index 34b97ec1a..94dc43b23 100644
--- a/docs/guide/algos.rst
+++ b/docs/guide/algos.rst
@@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin
 along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
-============ =========== ============ ================
-Name         ``Box``     ``Discrete`` Multi Processing
-============ =========== ============ ================
-A2C          ✔️          ✔️           ✔️
-PPO          ✔️          ✔️           ✔️
-SAC          ✔️          ❌           ❌
-TD3          ✔️          ❌           ❌
-============ =========== ============ ================
+============ =========== ============ ================= =============== ================
+Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
+============ =========== ============ ================= =============== ================
+A2C          ✔️          ✔️           ✔️                ✔️              ✔️
+PPO          ✔️          ✔️           ✔️                ✔️              ✔️
+SAC          ✔️          ❌           ❌                ❌              ❌
+TD3          ✔️          ❌           ❌                ❌              ❌
+============ =========== ============ ================= =============== ================
 .. note::
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 8d7290da4..cbdf36f9d 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
-Pre-Release 0.6.0a8 (WIP)
+Pre-Release 0.6.0a9 (WIP)
 ------------------------------
 Breaking Changes:
@@ -15,10 +15,9 @@ New Features:
 - Added env checker (Sync with Stable Baselines)
 - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
 - Added determinism tests
-- Added ``cmd_utils`` and ``atari_wrappers``
-- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces for PPO and A2C (@rolandgvc)
-- Added ``MultiCategorical`` and ``Bernoulli`` distributions (@rolandgvc)
-- Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` (@rolandgvc)
+- Added ``cmd_utils`` and ``atari_wrappers``
+- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
+- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
 Bug Fixes:
 ^^^^^^^^^^
diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst
index 38374f7e5..096778ba1 100644
--- a/docs/modules/a2c.rst
+++ b/docs/modules/a2c.rst
@@ -28,10 +28,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ✔️      ✔️
 Box           ✔️      ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ✔️      ✔️
+MultiBinary   ✔️      ✔️
 ============= ====== ===========
diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst
index fb83c8985..22fdf150b 100644
--- a/docs/modules/ppo.rst
+++ b/docs/modules/ppo.rst
@@ -38,10 +38,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ✔️      ✔️
 Box           ✔️      ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ✔️      ✔️
+MultiBinary   ✔️      ✔️
 ============= ====== ===========
 Example
diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst
index 359df4bde..4e777886f 100644
--- a/docs/modules/sac.rst
+++ b/docs/modules/sac.rst
@@ -58,10 +58,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ❌      ✔️
 Box           ✔️      ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ❌      ✔️
+MultiBinary   ❌      ✔️
 ============= ====== ===========
diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst
index 02ae39184..86a939dee 100644
--- a/docs/modules/td3.rst
+++ b/docs/modules/td3.rst
@@ -50,10 +50,10 @@ Can I use?
 ============= ====== ===========
 Space         Action Observation
 ============= ====== ===========
-Discrete      ❌      ❌
+Discrete      ❌      ✔️
 Box           ✔️      ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
+MultiDiscrete ❌      ✔️
+MultiBinary   ❌      ✔️
 ============= ====== ===========
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 437d1cc1c..187ac8b12 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -279,8 +279,8 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s
             elif observation.shape[1:] == observation_space.shape:
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape)
-                                 + "Box environment, please use {} ".format(observation_space.shape)
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+                                 + f"Box environment, please use {observation_space.shape} "
                                  + "or (n_env, {}) for the observation shape."
                                    .format(", ".join(map(str, observation_space.shape))))
         elif isinstance(observation_space, gym.spaces.Discrete):
@@ -289,7 +289,7 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s
             elif len(observation.shape) == 1:
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape)
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
                                  + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
         elif isinstance(observation_space, gym.spaces.MultiDiscrete):
@@ -298,21 +298,21 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s
             elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape)
-                                 + "environment, please use ({},) or ".format(len(observation_space.nvec))
-                                 + "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+                                 + f"environment, please use ({len(observation_space.nvec)},) or "
+                                 + f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
         elif isinstance(observation_space, gym.spaces.MultiBinary):
             if observation.shape == (observation_space.n,):
                 return False
             elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
                 return True
             else:
-                raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape)
-                                 + "environment, please use ({},) or ".format(observation_space.n)
-                                 + "(n_env, {}) for the observation shape.".format(observation_space.n))
+                raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+                                 + f"environment, please use ({observation_space.n},) or "
+                                 + f"(n_env, {observation_space.n}) for the observation shape.")
         else:
-            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
-                             .format(observation_space))
+            raise ValueError("Error: Cannot determine if the observation is vectorized "
+                             + f"with the space type {observation_space}.")
     def _get_data(self) -> Dict[str, Any]:
         """
@@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
         raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
     if name not in _policy_registry[base_policy_type]:
         raise ValueError(f"Error: unknown policy type {name},"
-                         "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
+                         f"the only registered policy types are: {list(_policy_registry[base_policy_type].keys())}!")
     return _policy_registry[base_policy_type][name]
@@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
     :param policy: (Type[BasePolicy]) the policy class
     """
     sub_class = None
-    # For building the doc
-    try:
-        for cls in BasePolicy.__subclasses__():
-            if issubclass(policy, cls):
-                sub_class = cls
-                break
-    except AttributeError:
-        sub_class = str(th.random.randint(100))
+    for cls in BasePolicy.__subclasses__():
+        if issubclass(policy, cls):
+            sub_class = cls
+            break
     if sub_class is None:
         raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index f3caa94e3..849756f17 100644
--- a/stable_baselines3/common/preprocessing.py
+++ b/stable_baselines3/common/preprocessing.py
@@ -110,6 +110,8 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
     :param observation_space: (spaces.Space)
     :return: (int)
     """
+    # See issue https://github.com/openai/gym/issues/1915
+    # it may be a problem for Dict/Tuple spaces too...
     if isinstance(observation_space, spaces.MultiDiscrete):
         return sum(observation_space.nvec)
     else:
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index df3ddb5f4..21c95036e 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-0.6.0a8
+0.6.0a9
diff --git a/tests/test_identity.py b/tests/test_identity.py
index d937c7e96..b41b70c2f 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -2,17 +2,27 @@
 import pytest
 from stable_baselines3 import A2C, PPO, SAC, TD3
-from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
+from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
+                                                   IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
+
+from stable_baselines3.common.vec_env import DummyVecEnv
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.noise import NormalActionNoise
+DIM = 4
+
+
 @pytest.mark.parametrize("model_class", [A2C, PPO])
-def test_discrete(model_class):
-    env = IdentityEnv(10)
-    model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000)
+@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
+def test_discrete(model_class, env):
+    env = DummyVecEnv([lambda: env])
+    model = model_class('MlpPolicy', env, gamma=0.5, seed=1).learn(3000)
     evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
+    obs = env.reset()
+
+    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
 @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index 668ac4a3e..dfd4a60e3 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,49 +1,47 @@
 import numpy as np
 import pytest
+import gym
-from stable_baselines3 import A2C, PPO
-from stable_baselines3.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
+from stable_baselines3 import SAC, TD3
 from stable_baselines3.common.evaluation import evaluate_policy
-from stable_baselines3.common.vec_env import DummyVecEnv
-MODEL_LIST = [A2C, PPO]
-DIM = 4
+class DummyMultiDiscreteSpace(gym.Env):
+    def __init__(self, nvec):
+        super(DummyMultiDiscreteSpace, self).__init__()
+        self.observation_space = gym.spaces.MultiDiscrete(nvec)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+    def reset(self):
+        return self.observation_space.sample()
-@pytest.mark.parametrize("model_class", MODEL_LIST)
-def test_identity_multidiscrete(model_class):
-    """
-    Test if the algorithm (with a given policy)
-    can learn an identity transformation (i.e. return observation as an action)
-    with a multidiscrete action space
-    :param model_class: (BaseRLModel) A RL Model
-    """
-    env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(DIM)])
+    def step(self, action):
+        return self.observation_space.sample(), 0.0, False, {}
-    model = model_class("MlpPolicy", env, gamma=0.5, seed=1)
-    model.learn(total_timesteps=3000)
-    obs = env.reset()
-    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
+class DummyMultiBinary(gym.Env):
+    def __init__(self, n):
+        super(DummyMultiBinary, self).__init__()
+        self.observation_space = gym.spaces.MultiBinary(n)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
-    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
+    def reset(self):
+        return self.observation_space.sample()
+    def step(self, action):
+        return self.observation_space.sample(), 0.0, False, {}
-@pytest.mark.parametrize("model_class", MODEL_LIST)
-def test_identity_multibinary(model_class):
+
+@pytest.mark.parametrize("model_class", [SAC, TD3])
+@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
+def test_identity_spaces(model_class, env):
     """
-    Test if the algorithm (with a given policy)
-    can learn an identity transformation (i.e. return observation as an action)
-    with a multibinary action space
-    :param model_class: (BaseRLModel) A RL Model
+    Additional tests for SAC/TD3 to check observation space support
+    for MultiDiscrete and MultiBinary.
     """
-    env = DummyVecEnv([lambda: IdentityEnvMultiBinary(DIM)])
-
-    model = model_class("MlpPolicy", env, gamma=0.5, seed=1)
-    model.learn(total_timesteps=3000)
-    obs = env.reset()
+    env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
-    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
+    model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
+    model.learn(total_timesteps=500)
-    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
+    evaluate_policy(model, env, n_eval_episodes=5)
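For reference, the snippet below is a minimal sketch (not part of the patch) of what the added support enables: training A2C on the `MultiDiscrete` and `MultiBinary` identity environments exercised by the updated tests. The imports, constructor arguments, and helper functions are taken from the test changes above; the training budget and hyperparameters are illustrative only.

```python
# Illustrative sketch only, not part of the patch: it exercises the
# MultiDiscrete / MultiBinary support this change adds for A2C and PPO.
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.vec_env import DummyVecEnv

for make_env in (lambda: IdentityEnvMultiDiscrete(4), lambda: IdentityEnvMultiBinary(4)):
    # Wrap the identity env in a DummyVecEnv, as the updated tests do
    env = DummyVecEnv([make_env])
    model = A2C('MlpPolicy', env, gamma=0.5, seed=1).learn(3000)
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
    print(mean_reward, std_reward)
```

For SAC and TD3 the patch only extends observation-space support (actions remain `Box`), which is what the rewritten `tests/test_spaces.py` covers with its dummy `MultiDiscrete`/`MultiBinary` observation environments.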