Skip to content

Commit

Permalink
More detailed schema checks for environments (#37)
Browse files Browse the repository at this point in the history
* Explicitly check shape and dtype (dtype at least is not checked by default)

* More informative test error messages

* Return correct dtype in environments

* Add tests for ObsCastWrapper

* Lint

* Black reformat
  • Loading branch information
AdamGleave authored Sep 21, 2020
1 parent 8c0d2ff commit 0b5135b
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 22 deletions.
7 changes: 6 additions & 1 deletion src/seals/classic_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(self):
high = np.array(high)
self.observation_space = spaces.Box(-high, high, dtype=np.float32)

def reset(self):
"""Reset for FixedHorizonCartPole."""
return super().reset().astype(np.float32)

def step(self, action):
"""Step function for FixedHorizonCartPole."""
with warnings.catch_warnings():
Expand All @@ -51,7 +55,7 @@ def step(self, action):
)

rew = 1.0 if state_ok else 0.0
return np.array(self.state), rew, False, {}
return np.array(self.state, dtype=np.float32), rew, False, {}


def mountain_car():
Expand All @@ -64,5 +68,6 @@ def mountain_car():
Done is always returned on timestep 200 only.
"""
env = util.make_env_no_wrappers("MountainCar-v0")
env = util.ObsCastWrapper(env, dtype=np.float32)
env = util.AbsorbAfterDoneWrapper(env)
return env
3 changes: 2 additions & 1 deletion src/seals/diagnostics/largest_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:

def initial_state(self) -> np.ndarray:
"""Returns vector sampled uniformly in [0, 1]**L."""
return self.rand_state.rand(self._length)
init_state = self.rand_state.rand(self._length)
return init_state.astype(self.observation_space.dtype)

def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float:
"""Returns +1.0 reward when action is the right label and 0.0 otherwise."""
Expand Down
2 changes: 1 addition & 1 deletion src/seals/diagnostics/noisy_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,4 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray:
def obs_from_state(self, state: np.ndarray) -> np.ndarray:
"""Returns (x, y) concatenated with Gaussian noise."""
noise_vector = self.rand_state.randn(self._noise_length)
return np.concatenate([state, noise_vector])
return np.concatenate([state, noise_vector]).astype(np.float32)
4 changes: 2 additions & 2 deletions src/seals/diagnostics/parabola.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def initial_state(self) -> np.ndarray:
"""Get state by sampling a random parabola."""
a, b, c = -1 + 2 * self.rand_state.rand(3)
x, y = 0, c
return np.array([x, y, a, b, c])
return np.array([x, y, a, b, c], dtype=self.state_space.dtype)

def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float:
"""Negative squared vertical distance from parabola."""
Expand All @@ -56,4 +56,4 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray:
x, y, a, b, c = state
next_x = np.clip(x + self._x_step, -self._bounds, self._bounds)
next_y = np.clip(y + action, -self._bounds, self._bounds)
return np.array([next_x, next_y, a, b, c])
return np.array([next_x, next_y, a, b, c], dtype=self.state_space.dtype)
3 changes: 2 additions & 1 deletion src/seals/diagnostics/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:

def initial_state(self):
"""Sample random vector uniformly in [0, 1]**L."""
return self.rand_state.random(size=self._length)
sample = self.rand_state.random(size=self._length)
return sample.astype(self.state_space.dtype)

def reward(
self,
Expand Down
40 changes: 26 additions & 14 deletions src/seals/testing/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,25 @@ def different_seeds_same_rollout(seed1, seed2):
assert same_obs == is_deterministic


def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None:
"""Check obs is consistent with obs_space."""
if obs_space.shape:
assert obs.shape == obs_space.shape
assert obs.dtype == obs_space.dtype
assert obs in obs_space


def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool:
"""Sample from env and check return value is of valid type."""
act = env.action_space.sample()
obs, rew, done, info = env.step(act)
_check_obs(obs, obs_space)
assert isinstance(rew, float)
assert isinstance(done, bool)
assert isinstance(info, dict)
return done


def test_rollout_schema(
env: gym.Env,
steps_after_done: int = 10,
Expand All @@ -171,26 +190,17 @@ def test_rollout_schema(
"""
obs_space = env.observation_space
obs = env.reset()
assert obs in obs_space

def _sample_and_check():
act = env.action_space.sample()
obs, rew, done, info = env.step(act)
assert obs in obs_space
assert isinstance(rew, float)
assert isinstance(done, bool)
assert isinstance(info, dict)
return done
_check_obs(obs, obs_space)

for _ in range(max_steps):
done = _sample_and_check()
done = _sample_and_check(env, obs_space)
if done:
break

assert done is True, "did not get to end of episode"

for _ in range(steps_after_done):
_sample_and_check()
_sample_and_check(env, obs_space)


def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None:
Expand Down Expand Up @@ -245,7 +255,7 @@ def __init__(self, episode_length: int = 5):
def reset(self):
"""Reset method for CountingEnv."""
t, self.timestep = 0, 1
return t
return np.array(t, dtype=self.observation_space.dtype)

def step(self, action):
"""Step method for CountingEnv."""
Expand All @@ -257,5 +267,7 @@ def step(self, action):
raise ValueError("Should reset env. Episode is over.")

t, self.timestep = self.timestep, self.timestep + 1
obs = np.array(t, dtype=self.observation_space.dtype)
rew = t * 10.0
done = t == self.episode_length
return t, t * 10, done, {}
return obs, rew, done, {}
29 changes: 29 additions & 0 deletions src/seals/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,35 @@ def step(self, action):
return obs, rew, False, info


class ObsCastWrapper(gym.Wrapper):
"""Cast observations to specified dtype.
Some external environments return observations of a different type than the
declared observation space. Where possible, this should be fixed upstream,
but casting can be a viable workaround -- especially when the returned
observations are higher resolution than the observation space.
"""

def __init__(self, env: gym.Env, dtype: np.dtype):
"""Builds ObsCastWrapper.
Args:
env: the environment to wrap.
dtype: the dtype to cast observations to.
"""
super().__init__(env)
self.dtype = dtype

def reset(self):
"""Returns reset observation, cast to self.dtype."""
return super().reset().astype(self.dtype)

def step(self, action):
"""Returns (obs, rew, done, info) with obs cast to self.dtype."""
obs, rew, done, info = super().step(action)
return obs.astype(self.dtype), rew, done, info


class AbsorbAfterDoneWrapper(gym.Wrapper):
"""Transition into absorbing state instead of episode termination.
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

pytest.register_assert_rewrite("seals.testing")


def pytest_addoption(parser):
"""Add --expensive option."""
Expand Down
6 changes: 4 additions & 2 deletions tests/test_mujoco_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
def _eval_env(
env_name: str,
total_timesteps: int,
) -> Tuple[float, int]: # pragma: no cover
) -> Tuple[float, float]: # pragma: no cover
"""Train PPO2 for `total_timesteps` on `env_name` and evaluate returns."""
env = gym.make(env_name)
model = PPO2(MlpPolicy, env)
model.learn(total_timesteps=total_timesteps)
return evaluate_policy(model, env)
res = evaluate_policy(model, env)
assert isinstance(res[0], float)
return res


@pytest.mark.expensive
Expand Down
25 changes: 25 additions & 0 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Tests for wrapper classes."""

import numpy as np
import pytest

from seals import util
from seals.testing import envs

Expand Down Expand Up @@ -81,3 +84,25 @@ def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset
expected_rew = t * 10.0
assert obs == expected_obs
assert rew == expected_rew


@pytest.mark.parametrize("dtype", [np.int, np.float32, np.float64])
def test_obs_cast(dtype: np.dtype, episode_length: int = 5):
"""Check obs_cast observations are of specified dtype and not mangled.
Test uses CountingEnv with small integers, which can be represented in
all the specified dtypes without any loss of precision.
"""
env = envs.CountingEnv(episode_length=episode_length)
env = util.ObsCastWrapper(env, dtype)

obs = env.reset()
assert obs.dtype == dtype
assert obs == 0
for t in range(1, episode_length + 1):
act = env.action_space.sample()
obs, rew, done, _ = env.step(act)
assert done == (t == episode_length)
assert obs.dtype == dtype
assert obs == t
assert rew == t * 10.0

0 comments on commit 0b5135b

Please sign in to comment.