Gymnasium Compatibility #735

Merged: 64 commits, Sep 23, 2023
4de8aae
Import gymnasium as gym everywhere and remove gym-specific fixes in t…
ernestum Jun 29, 2023
28c2e3f
Update dependencies
Rocamonde Jul 1, 2023
70e8652
gymnasium compatible reset, step and seed
EdoardoPona Jul 28, 2023
99f8a22
Update dependencies.
ernestum Aug 29, 2023
ae322eb
Fix gym->gymnasium import in SQIL implementation.
ernestum Aug 29, 2023
34b3240
Fix type violations found by mypy.
ernestum Aug 29, 2023
2840407
Downgrade furo for sphinx compatibility.
ernestum Aug 29, 2023
21b642a
Fix more mypy issues by introducing runtime type checks.
ernestum Sep 1, 2023
ce42eb9
Ensure up-to-date pip version that does modern dependency resolution …
ernestum Sep 6, 2023
5927a01
Set minimum version for gymnasium dependency.
ernestum Sep 6, 2023
73f2879
Specify seals minimum version.
ernestum Sep 6, 2023
4abd8eb
Fix bug in make_vec_env
ernestum Sep 6, 2023
6141ebd
Use max_episode_steps parameter of gym.make instead of constructing t…
ernestum Sep 6, 2023
e6759a4
Add missing trailing commas.
ernestum Sep 6, 2023
c0460a9
Add missing raises section in generate_trajectories()
ernestum Sep 6, 2023
83f18f8
Remove unused imports.
ernestum Sep 6, 2023
bb93264
Add seals[atari] to test requirements.
ernestum Sep 6, 2023
748706d
make ATARI_REQUIRE equivalent to seals[atari]
ernestum Sep 6, 2023
331e3d5
Fix signature of VideoWrapper so it properly overrides gymnasium.Wrap…
ernestum Sep 8, 2023
a65ffd0
Fix the way we access the spec.
ernestum Sep 8, 2023
dbc3a76
Use cast instead of assert for mypy
ernestum Sep 8, 2023
e27a41e
Place mypy assert hint in a better place.
ernestum Sep 8, 2023
746fab5
Fix test_replay_buffer_init_errors test.
ernestum Sep 8, 2023
c497afc
Fix implementation of ObsRewHalveWrapper.
ernestum Sep 8, 2023
ce61d69
Set huggingface_sb3 version to the correct one.
ernestum Sep 11, 2023
458d729
Don't install gym[atari] but just gym as a hack.
ernestum Sep 11, 2023
b24934c
Adapt logger test to new "None" representation of sb3 logger.
ernestum Sep 11, 2023
3def9a7
Make sb3 and hf-sb3 version specification less strict.
ernestum Sep 11, 2023
c7cd80d
Fix isort issue.
ernestum Sep 11, 2023
fba9625
Fix mypy issues.
ernestum Sep 11, 2023
82687a9
Import gymnasium instead of gym in interactive policies.
ernestum Sep 12, 2023
cdefe75
Fix trailing whitespace
ernestum Sep 12, 2023
2f297a1
Fix typing issue in interactive.py
ernestum Sep 12, 2023
92f403d
Pull seals from feature branch
ernestum Sep 12, 2023
23c7d77
Fix typing issue in test_interactive.py
ernestum Sep 12, 2023
1f1c2cb
Hacky pipeline fixes to load outdated expert models and disable windo…
ernestum Sep 8, 2023
9ec9469
Install gym in macos unit tests. TODO REMOVE BEFORE MERGING
ernestum Sep 12, 2023
f51b81d
Add missing documentation.
ernestum Sep 18, 2023
9976216
Remove unused import.
ernestum Sep 18, 2023
71e948a
Re-add macos and windows pipelines.
ernestum Sep 18, 2023
090b77b
Fix wrong == None comparison.
ernestum Sep 18, 2023
2cfb17a
Set seals version to upstream again.
ernestum Sep 19, 2023
d65abde
Undo hacky installation of gym in the pipeline as a temporary fix.
ernestum Sep 19, 2023
1f9338b
Update expert model in testdata folder to gymnasium.
ernestum Sep 19, 2023
98b8de8
Update SQIL tutorial to gymnasium.
ernestum Sep 19, 2023
5d5355c
Remove the requirement for specific setuptools and pip versions from …
ernestum Sep 19, 2023
235488a
Remove mujoco from the installation instructions (will be pulled as a…
ernestum Sep 19, 2023
9f8e3a3
Remove reference to gym in reward_networks.rst
ernestum Sep 19, 2023
73982bb
Seed the vecenv using seed and not using reset.
ernestum Sep 19, 2023
eec6018
Adapt the custom environment in the custom env tutorial to the gymnas…
ernestum Sep 19, 2023
a69c78a
Use proper env name in gail tutorial.
ernestum Sep 19, 2023
5c41ab3
Adapt the baselines comparison tutorial to the gymnasium API.
ernestum Sep 19, 2023
b226d0e
Remove some more gym imports from tutorials and examples.
ernestum Sep 19, 2023
87049f3
Reformat buffer.py for better readability.
ernestum Sep 20, 2023
376cc9d
Add actionable hints when generate_trajectories is called with a venv…
ernestum Sep 20, 2023
b5a488f
Ignore info dict when generating rollouts for MCE tests.
ernestum Sep 20, 2023
fd86269
Remove note from MCE tests.
ernestum Sep 20, 2023
add90d7
Black fixes.
ernestum Sep 20, 2023
f903e77
Add missing trailing comma.
ernestum Sep 20, 2023
aa16cf9
Upgrade pre-commit tool versions.
ernestum Sep 20, 2023
ddff463
Flake8 fixes.
ernestum Sep 20, 2023
9dead7b
Black and codespell fixes.
ernestum Sep 20, 2023
efd79b0
Place inline comments in such a way that flake8 and black accept them.
ernestum Sep 21, 2023
d9c1658
Add isort fix and ensure isort detects wandb as thirdparty.
ernestum Sep 21, 2023
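The recurring theme across these commits is the signature change between the gym and gymnasium APIs: `reset()` now returns an `(obs, info)` pair, and `step()` returns a five-tuple that splits the old `done` flag into `terminated` and `truncated`. A minimal pure-Python sketch of an adapter between the two conventions (the class and helper names are illustrative, not taken from this PR):

```python
class OldStyleEnv:
    """Minimal stand-in for a pre-gymnasium environment."""

    def reset(self):
        return 0.0  # old API: just the observation

    def step(self, action):
        obs, reward, done, info = 1.0, 0.5, True, {}
        return obs, reward, done, info  # old API: four values


def to_gymnasium_reset(old_reset_result):
    """Old reset() returned only obs; gymnasium returns (obs, info)."""
    return old_reset_result, {}


def to_gymnasium_step(old_step_result):
    """Map the old 4-tuple step result to gymnasium's 5-tuple.

    A bare `done` flag carries no truncation information, so it is
    conservatively mapped to `terminated`, with `truncated` set False.
    """
    obs, reward, done, info = old_step_result
    return obs, reward, done, False, info
```

In this PR the adaptation mostly goes the other way: call sites across the codebase are updated to consume the new tuples directly rather than wrapping old environments.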
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -65,7 +65,7 @@ commands:
# Download and cache dependencies
- restore_cache:
keys:
- v7linux-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.sh" }}
- v8linux-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.sh" }}

- run:
name: install dependencies
@@ -75,7 +75,7 @@ commands:
- save_cache:
paths:
- /venv
key: v7linux-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.sh" }}
key: v8linux-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.sh" }}

- run:
name: install imitation
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -3,7 +3,7 @@
repos:
# Linting
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-ast
- id: trailing-whitespace
@@ -12,7 +12,7 @@ repos:
- id: check-toml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 23.9.1
hooks:
- id: black
- id: black-jupyter
@@ -22,7 +22,7 @@
- id: isort
# Python static analysis
- repo: https://github.com/pycqa/flake8
rev: '5.0.4'
rev: '6.1.0'
hooks:
- id: flake8
additional_dependencies:
@@ -34,7 +34,7 @@
- flake8-docstrings~=1.6.0
# Shell static analysis
- repo: https://github.com/koalaman/shellcheck-precommit
rev: v0.8.0
rev: v0.9.0
hooks:
- id: shellcheck
# precommit invokes shellcheck once per file. shellcheck complains if file
@@ -43,12 +43,12 @@
args: ["-e", "SC1091"]
# Misc
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.4
hooks:
- id: codespell
args: ["--skip=*.pyc,tests/testdata/*,*.ipynb,*.csv","--ignore-words-list=reacher,ith,iff"]
- repo: https://github.com/syntaqx/git-hooks
rev: v0.0.17
rev: v0.0.18
hooks:
- id: circleci-config-validate
# Hooks that run in local environment (not isolated venv) as they need
4 changes: 0 additions & 4 deletions ci/build_and_activate_venv.ps1
@@ -9,8 +9,4 @@ If ($venv -eq $null) {

virtualenv -p python3.8 $venv
& $venv\Scripts\activate
# Note: We need to install these versions of setuptools and wheel to allow installing gym==0.21.0 on Windows.
# See https://github.com/freqtrade/freqtrade/issues/8376
# TODO(GH#707): remove pin once upgraded Gym
python -m pip install --upgrade pip wheel==0.38.4 setuptools==65.5.1
pip install ".[docs,parallel,test]"
5 changes: 3 additions & 2 deletions ci/build_and_activate_venv.sh
@@ -20,8 +20,9 @@ fi
virtualenv -p ${python_version} ${venv}
# shellcheck disable=SC1090,SC1091
source ${venv}/bin/activate
# Note: We need to install setuptools==66.1.1 to allow installing gym==0.21.0.
python -m pip install --upgrade pip setuptools==66.1.1

# Update pip to the latest version.
pip install --upgrade pip

# If platform is linux, install pytorch CPU version.
# This will prevent installing the CUDA version in the pip install ".[docs,parallel,test]" command.
1 change: 0 additions & 1 deletion ci/clean_notebooks.py
@@ -63,7 +63,6 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:
print(f"Checking {file}")

for cell in nb.cells:

# Remove empty cells
if cell["cell_type"] == "code" and not cell["source"]:
if check_only:
4 changes: 2 additions & 2 deletions docs/algorithms/airl.rst
@@ -23,7 +23,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
:skipif: skip_doctests

import numpy as np
import seals # noqa: F401 # needed to load "seals/" environments
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
@@ -39,7 +39,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
SEED = 42

env = make_vec_env(
"seals/CartPole-v0",
"seals:seals/CartPole-v0",
rng=np.random.default_rng(SEED),
n_envs=8,
post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # to compute rollouts
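The docs diffs here and below replace the side-effect import `import seals  # noqa: F401` with the namespaced id `"seals:seals/CartPole-v0"`: gymnasium accepts `module:env-id` strings and imports the named module before looking the environment up in its registry, so the explicit import is no longer needed. A small helper showing the string convention (purely illustrative — gymnasium does this parsing itself):

```python
def split_env_id(env_id: str):
    """Split a gymnasium-style "module:namespace/EnvName-vN" id into the
    module that must be imported and the id looked up in the registry."""
    module, sep, registry_id = env_id.partition(":")
    if not sep:  # no colon: nothing to import, the whole string is the id
        return None, env_id
    return module, registry_id
```

For `"seals:seals/CartPole-v0"`, the part before the colon (`seals`) is the package to import, and the remainder (`seals/CartPole-v0`) is the namespaced registry id.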
4 changes: 2 additions & 2 deletions docs/algorithms/bc.rst
@@ -21,7 +21,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc`
:skipif: skip_doctests

import numpy as np
import seals # noqa: F401 # needed to load "seals/" environments
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy

from imitation.algorithms import bc
@@ -32,7 +32,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc`

rng = np.random.default_rng(0)
env = make_vec_env(
"seals/CartPole-v0",
"seals:seals/CartPole-v0",
rng=rng,
n_envs=1,
post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # for computing rollouts
4 changes: 2 additions & 2 deletions docs/algorithms/dagger.rst
@@ -26,7 +26,7 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`
import tempfile

import numpy as np
import seals # noqa: F401 # needed to load "seals/" environments
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy

from imitation.algorithms import bc
@@ -36,7 +36,7 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`

rng = np.random.default_rng(0)
env = make_vec_env(
"seals/CartPole-v0",
"seals:seals/CartPole-v0",
rng=rng,
)
expert = load_policy(
4 changes: 2 additions & 2 deletions docs/algorithms/gail.rst
@@ -20,7 +20,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
:skipif: skip_doctests

import numpy as np
import seals # noqa: F401 # needed to load "seals/" environments
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
@@ -36,7 +36,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
SEED = 42

env = make_vec_env(
"seals/CartPole-v0",
"seals:seals/CartPole-v0",
rng=np.random.default_rng(SEED),
n_envs=8,
post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # to compute rollouts
2 changes: 1 addition & 1 deletion docs/algorithms/sqil.rst
@@ -28,7 +28,7 @@ Detailed example notebook: :doc:`../tutorials/8_train_sqil`
:skipif: skip_doctests

import datasets
import gym
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

13 changes: 1 addition & 12 deletions docs/getting-started/installation.rst
@@ -6,13 +6,7 @@ Prerequisites
-------------

- Python 3.8+
- Specific versions of pip and setuptools due to \
`a bug with gym <https://github.com/openai/gym/issues/3176>`_:

.. code-block:: bash

pip install -U setuptools==65.5.0 pip==21

- pip (it helps to make sure this is up-to-date: ``pip install -U pip``)
- (on ARM64 Macs) you need to set environment variables due to \
`a bug in grpcio <https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop>`_:

@@ -23,11 +17,6 @@ Prerequisites

- (Optional) OpenGL (to render gym environments)
- (Optional) FFmpeg (to encode videos of renders)
- (Optional) MuJoCo (follow instructions to install `mujoco\_py v1.5 here`_)

.. _mujoco_py v1.5 here:
https://github.com/openai/mujoco-py/tree/498b451a03fb61e5bdfcb6956d8d7c881b1098b5#install-mujoco


Installation from PyPI
----------------------
2 changes: 1 addition & 1 deletion docs/main-concepts/reward_networks.rst
@@ -47,7 +47,7 @@ In order to use a reward network to train a policy, we need to integrate it into

import numpy as np
rng = np.random.default_rng(0)
from gym.spaces import Box
from gymnasium.spaces import Box
obs_space = Box(np.ones(2), np.ones(2))
action_space = Box(np.ones(5), np.ones(5))

37 changes: 17 additions & 20 deletions docs/tutorials/10_train_custom_env.ipynb
@@ -34,32 +34,30 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, Optional\n",
"from typing import Any\n",
"import numpy as np\n",
"import gym\n",
"import gymnasium as gym\n",
"\n",
"from gym.spaces import Box\n",
"from gym.utils import seeding\n",
"from gymnasium.spaces import Box\n",
"\n",
"\n",
"class ObservationMatchingEnv(gym.Env):\n",
" def __init__(self, num_options: int = 2):\n",
" self.state = None\n",
" self.num_options = num_options\n",
" self.observation_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n",
" self.action_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n",
" self.seed()\n",
" self.observation_space = Box(0, 1, shape=(num_options,))\n",
" self.action_space = Box(0, 1, shape=(num_options,))\n",
"\n",
" def seed(self, seed=None):\n",
" self.np_random, seed = seeding.np_random(seed)\n",
" return [seed]\n",
"\n",
" def reset(self):\n",
" self.state = self.np_random.uniform(size=self.num_options)\n",
" return self.state\n",
" def reset(self, seed: int = None, options: Optional[Dict[str, Any]] = None):\n",
" super().reset(seed=seed, options=options)\n",
" self.state = self.observation_space.sample()\n",
" return self.state, {}\n",
"\n",
" def step(self, action):\n",
" reward = -np.abs(self.state - action).mean()\n",
" self.state = self.np_random.uniform(size=self.num_options)\n",
" return self.state, reward, False, {}"
" self.state = self.observation_space.sample()\n",
" return self.state, reward, False, False, {}"
]
},
{
@@ -126,7 +124,7 @@
"metadata": {},
"outputs": [],
"source": [
"from gym.wrappers import TimeLimit\n",
"from gymnasium.wrappers import TimeLimit\n",
"from imitation.data import rollout\n",
"from imitation.data.wrappers import RolloutInfoWrapper\n",
"from imitation.util.util import make_vec_env\n",
@@ -176,7 +174,7 @@
"metadata": {},
"outputs": [],
"source": [
"from gym.wrappers import TimeLimit\n",
"from gymnasium.wrappers import TimeLimit\n",
"from imitation.data import rollout\n",
"from imitation.data.wrappers import RolloutInfoWrapper\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
@@ -236,7 +234,7 @@
"from stable_baselines3 import PPO\n",
"from stable_baselines3.ppo import MlpPolicy\n",
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"from gym.wrappers import TimeLimit\n",
"from gymnasium.wrappers import TimeLimit\n",
"\n",
"expert = PPO(\n",
" policy=MlpPolicy,\n",
@@ -266,8 +264,7 @@
"# n_steps=64,\n",
"# )\n",
"expert.learn(10_000) # Note: set to 100000 to train a proficient expert\n",
"\n",
"reward, _ = evaluate_policy(expert, env, 10)\n",
"reward, _ = evaluate_policy(expert, expert.get_env(), 10)\n",
"print(f\"Expert reward: {reward}\")"
]
},
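The notebook diff above rewrites `ObservationMatchingEnv` for gymnasium: the separate `seed()` method disappears, seeding moves into `reset(seed=...)`, `reset()` returns `(obs, info)`, and `step()` returns five values. A dependency-free sketch that mimics the same shape (using the stdlib `random` module in place of gymnasium's seeding machinery; illustrative only, not the PR's code):

```python
import random
from typing import Any, Dict, Optional


class MatchingEnvSketch:
    """Pure-Python mimic of the updated environment's API shape."""

    def __init__(self, num_options: int = 2):
        self.num_options = num_options
        self.np_random = random.Random()
        self.state = None

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ):
        if seed is not None:
            # gymnasium's super().reset(seed=seed) re-seeds self.np_random
            self.np_random.seed(seed)
        self.state = [self.np_random.random() for _ in range(self.num_options)]
        return self.state, {}  # new API: (obs, info)

    def step(self, action):
        # Reward is the (negated) mean absolute difference to the target state.
        diffs = [abs(s - a) for s, a in zip(self.state, action)]
        reward = -sum(diffs) / self.num_options
        self.state = [self.np_random.random() for _ in range(self.num_options)]
        # new API: (obs, reward, terminated, truncated, info)
        return self.state, reward, False, False, {}
```

Because the seed is now threaded through `reset`, two instances reset with the same seed produce identical observations.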
4 changes: 2 additions & 2 deletions docs/tutorials/1_train_bc.ipynb
@@ -32,13 +32,13 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"import gymnasium as gym\n",
"from imitation.policies.serialize import load_policy\n",
"from imitation.util.util import make_vec_env\n",
"from imitation.data.wrappers import RolloutInfoWrapper\n",
"\n",
"env = make_vec_env(\n",
" \"seals/CartPole-v0\",\n",
" \"seals:seals/CartPole-v0\",\n",
" rng=np.random.default_rng(),\n",
" post_wrappers=[\n",
" lambda env, _: RolloutInfoWrapper(env)\n",
4 changes: 2 additions & 2 deletions docs/tutorials/2_train_dagger.ipynb
@@ -27,12 +27,12 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"import gymnasium as gym\n",
"from imitation.policies.serialize import load_policy\n",
"from imitation.util.util import make_vec_env\n",
"\n",
"env = make_vec_env(\n",
" \"seals/CartPole-v0\",\n",
" \"seals:seals/CartPole-v0\",\n",
" rng=np.random.default_rng(),\n",
" n_envs=1,\n",
")\n",
5 changes: 2 additions & 3 deletions docs/tutorials/3_train_gail.ipynb
@@ -27,15 +27,14 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"from imitation.policies.serialize import load_policy\n",
"from imitation.util.util import make_vec_env\n",
"from imitation.data.wrappers import RolloutInfoWrapper\n",
"\n",
"SEED = 42\n",
"\n",
"env = make_vec_env(\n",
" \"seals/CartPole-v0\",\n",
" \"seals:seals/CartPole-v0\",\n",
" rng=np.random.default_rng(SEED),\n",
" n_envs=8,\n",
" post_wrappers=[\n",
@@ -45,7 +44,7 @@
"expert = load_policy(\n",
" \"ppo-huggingface\",\n",
" organization=\"HumanCompatibleAI\",\n",
" env_name=\"seals-CartPole-v0\",\n",
" env_name=\"seals:seals/CartPole-v0\",\n",
" venv=env,\n",
")"
]
4 changes: 2 additions & 2 deletions docs/tutorials/4_train_airl.ipynb
@@ -24,7 +24,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"import gymnasium as gym\n",
"from imitation.policies.serialize import load_policy\n",
"from imitation.util.util import make_vec_env\n",
"from imitation.data.wrappers import RolloutInfoWrapper\n",
@@ -127,7 +127,7 @@
" reward_net=reward_net,\n",
")\n",
"\n",
"env.seed(SEED)\n",
"env.reset(seed=SEED)\n",
"learner_rewards_before_training, _ = evaluate_policy(\n",
" learner, env, 100, return_episode_rewards=True\n",
")\n",
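The AIRL notebook diff above replaces `env.seed(SEED)` with `env.reset(seed=SEED)`, matching gymnasium's removal of the standalone `seed()` method. A hedged pure-Python sketch of how a compatibility shim could keep legacy `seed()` call sites working against a new-style environment (all names here are hypothetical, not from the PR):

```python
import random


class NewStyleEnv:
    """Minimal environment that only supports reset(seed=...)."""

    def __init__(self):
        self.np_random = random.Random()

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.np_random.seed(seed)
        obs = self.np_random.random()
        return obs, {}


class LegacySeedShim:
    """Lets old `env.seed(SEED)` call sites drive a new-style env."""

    def __init__(self, env):
        self.env = env
        self._pending_seed = None

    def seed(self, seed):
        # Defer: the seed takes effect on the next reset, as in gymnasium.
        self._pending_seed = seed

    def reset(self, **kwargs):
        seed = kwargs.pop("seed", self._pending_seed)
        self._pending_seed = None
        return self.env.reset(seed=seed, **kwargs)
```

The PR takes the simpler route of updating call sites directly; a shim like this would only matter for third-party code still calling `seed()`.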
2 changes: 2 additions & 0 deletions docs/tutorials/5_train_preference_comparisons.ipynb
@@ -23,11 +23,13 @@
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from imitation.algorithms import preference_comparisons\n",
"from imitation.rewards.reward_nets import BasicRewardNet\n",
"from imitation.util.networks import RunningNorm\n",
"from imitation.util.util import make_vec_env\n",
"from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n",
"import gymnasium as gym\n",
"from stable_baselines3 import PPO\n",
"import numpy as np\n",
"\n",