From 0b5135b0d05410b4f31da54261aea9ca44d4edf4 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Mon, 21 Sep 2020 15:05:18 +0100 Subject: [PATCH] More detailed schema checks for environments (#37) * Explicitly check shape and dtype (dtype at least is not checked by default) * More informative test error messages * Return correct dtype in environments * Add tests for ObsCastWrapper * Lint * Black reformat --- src/seals/classic_control.py | 7 ++++- src/seals/diagnostics/largest_sum.py | 3 ++- src/seals/diagnostics/noisy_obs.py | 2 +- src/seals/diagnostics/parabola.py | 4 +-- src/seals/diagnostics/sort.py | 3 ++- src/seals/testing/envs.py | 40 ++++++++++++++++++---------- src/seals/util.py | 29 ++++++++++++++++++++ tests/conftest.py | 2 ++ tests/test_mujoco_rl.py | 6 +++-- tests/test_wrappers.py | 25 +++++++++++++++++ 10 files changed, 99 insertions(+), 22 deletions(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 6aaa385..9854095 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -32,6 +32,10 @@ def __init__(self): high = np.array(high) self.observation_space = spaces.Box(-high, high, dtype=np.float32) + def reset(self): + """Reset for FixedHorizonCartPole.""" + return super().reset().astype(np.float32) + def step(self, action): """Step function for FixedHorizonCartPole.""" with warnings.catch_warnings(): @@ -51,7 +55,7 @@ def step(self, action): ) rew = 1.0 if state_ok else 0.0 - return np.array(self.state), rew, False, {} + return np.array(self.state, dtype=np.float32), rew, False, {} def mountain_car(): @@ -64,5 +68,6 @@ def mountain_car(): Done is always returned on timestep 200 only. """ env = util.make_env_no_wrappers("MountainCar-v0") + env = util.ObsCastWrapper(env, dtype=np.float32) env = util.AbsorbAfterDoneWrapper(env) return env diff --git a/src/seals/diagnostics/largest_sum.py b/src/seals/diagnostics/largest_sum.py index 49cea74..d55021f 100644 --- a/src/seals/diagnostics/largest_sum.py +++ b/src/seals/diagnostics/largest_sum.py @@ -36,7 +36,8 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Returns vector sampled uniformly in [0, 1]**L.""" - return self.rand_state.rand(self._length) + init_state = self.rand_state.rand(self._length) + return init_state.astype(self.observation_space.dtype) def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float: """Returns +1.0 reward when action is the right label and 0.0 otherwise.""" diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index 1e4658c..0ce060d 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -65,4 +65,4 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray: def obs_from_state(self, state: np.ndarray) -> np.ndarray: """Returns (x, y) concatenated with Gaussian noise.""" noise_vector = self.rand_state.randn(self._noise_length) - return np.concatenate([state, noise_vector]) + return np.concatenate([state, noise_vector]).astype(np.float32) diff --git a/src/seals/diagnostics/parabola.py b/src/seals/diagnostics/parabola.py index 9678634..2a84163 100644 --- a/src/seals/diagnostics/parabola.py +++ b/src/seals/diagnostics/parabola.py @@ -43,7 +43,7 @@ def initial_state(self) -> np.ndarray: """Get state by sampling a random parabola.""" a, b, c = -1 + 2 * self.rand_state.rand(3) x, y = 0, c - return np.array([x, y, a, b, c]) + return np.array([x, y, a, b, c], dtype=self.state_space.dtype) def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: """Negative squared vertical distance from parabola.""" @@ -56,4 +56,4 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray: x, y, a, b, c = state next_x = np.clip(x + self._x_step, -self._bounds, self._bounds) next_y = np.clip(y + action, -self._bounds, self._bounds) - return np.array([next_x, next_y, a, b, c]) + return np.array([next_x, next_y, a, b, c], dtype=self.state_space.dtype) diff --git a/src/seals/diagnostics/sort.py b/src/seals/diagnostics/sort.py index 0e08849..68d1c72 100644 --- a/src/seals/diagnostics/sort.py +++ b/src/seals/diagnostics/sort.py @@ -32,7 +32,8 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self): """Sample random vector uniformly in [0, 1]**L.""" - return self.rand_state.random(size=self._length) + sample = self.rand_state.random(size=self._length) + return sample.astype(self.state_space.dtype) def reward( self, diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index e78664e..558dbc7 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -152,6 +152,25 @@ def different_seeds_same_rollout(seed1, seed2): assert same_obs == is_deterministic +def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None: + """Check obs is consistent with obs_space.""" + if obs_space.shape: + assert obs.shape == obs_space.shape + assert obs.dtype == obs_space.dtype + assert obs in obs_space + + +def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool: + """Sample from env and check return value is of valid type.""" + act = env.action_space.sample() + obs, rew, done, info = env.step(act) + _check_obs(obs, obs_space) + assert isinstance(rew, float) + assert isinstance(done, bool) + assert isinstance(info, dict) + return done + + def test_rollout_schema( env: gym.Env, steps_after_done: int = 10, @@ -171,26 +190,17 @@ def test_rollout_schema( """ obs_space = env.observation_space obs = env.reset() - assert obs in obs_space - - def _sample_and_check(): - act = env.action_space.sample() - obs, rew, done, info = env.step(act) - assert obs in obs_space - assert isinstance(rew, float) - assert isinstance(done, bool) - assert isinstance(info, dict) - return done + _check_obs(obs, obs_space) for _ in range(max_steps): - done = _sample_and_check() + done = _sample_and_check(env, obs_space) if done: break assert done is True, "did not get to end of episode" for _ in range(steps_after_done): - _sample_and_check() + _sample_and_check(env, obs_space) def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None: @@ -245,7 +255,7 @@ def __init__(self, episode_length: int = 5): def reset(self): """Reset method for CountingEnv.""" t, self.timestep = 0, 1 - return t + return np.array(t, dtype=self.observation_space.dtype) def step(self, action): """Step method for CountingEnv.""" @@ -257,5 +267,7 @@ def step(self, action): raise ValueError("Should reset env. Episode is over.") t, self.timestep = self.timestep, self.timestep + 1 + obs = np.array(t, dtype=self.observation_space.dtype) + rew = t * 10.0 done = t == self.episode_length - return t, t * 10, done, {} + return obs, rew, done, {} diff --git a/src/seals/util.py b/src/seals/util.py index 14522ea..398c59a 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -23,6 +23,35 @@ def step(self, action): return obs, rew, False, info +class ObsCastWrapper(gym.Wrapper): + """Cast observations to specified dtype. + + Some external environments return observations of a different type than the + declared observation space. Where possible, this should be fixed upstream, + but casting can be a viable workaround -- especially when the returned + observations are higher resolution than the observation space. + """ + + def __init__(self, env: gym.Env, dtype: np.dtype): + """Builds ObsCastWrapper. + + Args: + env: the environment to wrap. + dtype: the dtype to cast observations to. + """ + super().__init__(env) + self.dtype = dtype + + def reset(self): + """Returns reset observation, cast to self.dtype.""" + return super().reset().astype(self.dtype) + + def step(self, action): + """Returns (obs, rew, done, info) with obs cast to self.dtype.""" + obs, rew, done, info = super().step(action) + return obs.astype(self.dtype), rew, done, info + + class AbsorbAfterDoneWrapper(gym.Wrapper): """Transition into absorbing state instead of episode termination. diff --git a/tests/conftest.py b/tests/conftest.py index 7058099..bdbf6da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,8 @@ import pytest +pytest.register_assert_rewrite("seals.testing") + def pytest_addoption(parser): """Add --expensive option.""" diff --git a/tests/test_mujoco_rl.py b/tests/test_mujoco_rl.py index ee02003..ca64c68 100644 --- a/tests/test_mujoco_rl.py +++ b/tests/test_mujoco_rl.py @@ -14,12 +14,14 @@ def _eval_env( env_name: str, total_timesteps: int, -) -> Tuple[float, int]: # pragma: no cover +) -> Tuple[float, float]: # pragma: no cover """Train PPO2 for `total_timesteps` on `env_name` and evaluate returns.""" env = gym.make(env_name) model = PPO2(MlpPolicy, env) model.learn(total_timesteps=total_timesteps) - return evaluate_policy(model, env) + res = evaluate_policy(model, env) + assert isinstance(res[0], float) + return res @pytest.mark.expensive diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 151a1ac..0b18e4d 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,5 +1,8 @@ """Tests for wrapper classes.""" +import numpy as np +import pytest + from seals import util from seals.testing import envs @@ -81,3 +84,25 @@ def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset expected_rew = t * 10.0 assert obs == expected_obs assert rew == expected_rew + + +@pytest.mark.parametrize("dtype", [np.int, np.float32, np.float64]) +def test_obs_cast(dtype: np.dtype, episode_length: int = 5): + """Check obs_cast observations are of specified dtype and not mangled. + + Test uses CountingEnv with small integers, which can be represented in + all the specified dtypes without any loss of precision. + """ + env = envs.CountingEnv(episode_length=episode_length) + env = util.ObsCastWrapper(env, dtype) + + obs = env.reset() + assert obs.dtype == dtype + assert obs == 0 + for t in range(1, episode_length + 1): + act = env.action_space.sample() + obs, rew, done, _ = env.step(act) + assert done == (t == episode_length) + assert obs.dtype == dtype + assert obs == t + assert rew == t * 10.0