diff --git a/src/imitation/testing/envs.py b/src/imitation/testing/envs.py deleted file mode 100644 index 72978dd1a..000000000 --- a/src/imitation/testing/envs.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Helper methods for tests of custom Gym environments. - -This is used in the imitation test suite and may also be useful for users of this -library. -""" - -import gym - - -def test_model_based(env: gym.Env) -> None: - """Smoke test for each of the ModelBasedEnv methods with type checks. - - Args: - env: The environment to test. - - Raises: - AssertionError if test fails. - """ - state = env.initial_state() - assert env.state_space.contains(state) - - action = env.action_space.sample() - new_state = env.transition(state, action) - assert env.state_space.contains(new_state) - - reward = env.reward(state, action, new_state) - assert isinstance(reward, float) - - done = env.terminal(state, 0) - assert isinstance(done, bool) - - obs = env.obs_from_state(state) - assert env.observation_space.contains(obs) - next_obs = env.obs_from_state(new_state) - assert env.observation_space.contains(next_obs) diff --git a/tests/test_envs.py b/tests/test_envs.py deleted file mode 100644 index 194a0f596..000000000 --- a/tests/test_envs.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Tests for seal environments.""" -from typing import List - -import gym -import pytest -from seals.testing import envs as seals_test - -from imitation.testing import envs as imitation_test - -ENV_NAMES = [ - env_spec.id - for env_spec in gym.envs.registration.registry.all() - if env_spec.id.startswith("imitation/") -] - -DETERMINISTIC_ENVS: List[str] = [] - -env = pytest.fixture(seals_test.make_env_fixture(skip_fn=pytest.skip)) - - -@pytest.mark.parametrize("env_name", ENV_NAMES) -class TestEnvs: - """Battery of simple tests for environments.""" - - def test_seed(self, env, env_name): - seals_test.test_seed(env, env_name, DETERMINISTIC_ENVS) - - def test_premature_step(self, env): - """Test that you must call reset() before calling step().""" - seals_test.test_premature_step( - env, - skip_fn=pytest.skip, - raises_fn=pytest.raises, - ) - - def test_model_based(self, env): - """Smoke test for each of the ModelBasedEnv methods with type checks.""" - if not hasattr(env, "pomdp_state_space"): # pragma: no cover - pytest.skip("This test is only for subclasses of ResettableEnv.") - - imitation_test.test_model_based(env) - - def test_rollout_schema(self, env: gym.Env): - """Tests if environments have correct types on `step()` and `reset()`.""" - seals_test.test_rollout_schema(env) - - def test_render(self, env: gym.Env): - """Tests `render()` supports modes specified in environment metadata.""" - seals_test.test_render(env, raises_fn=pytest.raises)