Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate imitation envs to seals #58

Merged
merged 44 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
291f514
Initial version of imitation+seals merge of POMDP/MDP environments.
Rocamonde Aug 26, 2022
4996487
Bug fixes to make tests pass
Rocamonde Aug 26, 2022
3ad11c2
Linting and typing
Rocamonde Aug 26, 2022
e6cbb7f
Ran black
Rocamonde Aug 26, 2022
579ae4d
Trailing comma to make linter happy
Rocamonde Aug 27, 2022
cfe59f5
Fix absurd fight between black and flake8
Rocamonde Aug 27, 2022
38fa563
Fix array access in bash script (code_checks.sh)
Rocamonde Aug 27, 2022
ab9a46e
Added Makefile to simplify local CI checking
Rocamonde Aug 27, 2022
7046feb
Removed pytype restriction to only python 3.7
Rocamonde Aug 27, 2022
6a8e21a
Added "type: ignore" on call to numpy method with incorrect type sign…
Rocamonde Aug 27, 2022
e157bab
Fixed error on incompatible type signature due to inheritance by addi…
Rocamonde Aug 27, 2022
b2f0ae2
Added imitation examples (to be moved to a better file)
Rocamonde Aug 27, 2022
40286da
Increased max line length in linting
Rocamonde Aug 27, 2022
42fbdc5
Linting and docstrings
Rocamonde Aug 27, 2022
3a5d50b
Small fixes
Rocamonde Aug 27, 2022
bba45eb
Bug fixes
Rocamonde Aug 27, 2022
49dd9f0
Attempt to fix box boundary overflow
Rocamonde Aug 28, 2022
b7dc25f
Attempt to fix inf to int overflow
Rocamonde Aug 29, 2022
cf97099
Fix bug in ResettablePOMDP
Rocamonde Sep 6, 2022
0ac051f
Merge branch 'master' into imitation-envs-to-seals
AdamGleave Sep 9, 2022
0beb871
Remove makefile for now
Rocamonde Sep 11, 2022
11160c0
Roll back line length for now
Rocamonde Sep 11, 2022
bc46368
Fix matplotlib issue
Rocamonde Sep 11, 2022
6b85775
Remove type ignore
Rocamonde Sep 11, 2022
fccb19b
Miscellaneous improvements from review feedback
Rocamonde Sep 11, 2022
f73ca2e
Merge branch 'imitation-envs-to-seals' of github.com:HumanCompatibleA…
Rocamonde Sep 11, 2022
61b7056
Restructure imitation examples into adequate
Rocamonde Sep 11, 2022
8fbea84
Improve coverage
Rocamonde Sep 11, 2022
580f77a
Fix typo
Rocamonde Sep 11, 2022
c758d4e
Add docstring to test
Rocamonde Sep 11, 2022
4addc5b
Additional test coverage improvements
Rocamonde Sep 11, 2022
5464af7
Add docstrings to test
Rocamonde Sep 11, 2022
a822a46
Rearrange observation matrix in POMDP
Rocamonde Sep 30, 2022
0391192
Fix space constructors
Rocamonde Sep 30, 2022
9f84893
Switch to python 3.8 minimum
Rocamonde Oct 3, 2022
6d9879f
Add docstring
Rocamonde Oct 3, 2022
91b55d2
Merge remote-tracking branch 'origin' into imitation-envs-to-seals
Rocamonde Oct 3, 2022
c714e59
Improve test coverage
Rocamonde Oct 3, 2022
54cb4b5
Reorder imports
Rocamonde Oct 3, 2022
427f67d
Added docstring to tests and exceptions
Rocamonde Oct 3, 2022
b03a680
Final coverage fixes
Rocamonde Oct 3, 2022
f533812
Move cliff world registration to function
Rocamonde Oct 4, 2022
0c66584
Remove comment
Rocamonde Oct 4, 2022
94fbf49
Add docstring
Rocamonde Oct 4, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions Makefile

This file was deleted.

4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ strictness=long
[flake8]
docstring-convention=google
ignore = E203, W503
max-line-length = 100
max-line-length = 88

[isort]
line_length=88
Expand All @@ -38,7 +38,7 @@ inputs =
src/
tests/
setup.py
python_version >= 3.7
python_version >= 3.8

[tool:pytest]
markers =
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def get_readme() -> str:
"flake8-docstrings",
"flake8-isort",
"isort",
"matplotlib",
"mypy",
"pydocstyle",
"pytest",
Expand Down Expand Up @@ -137,7 +138,7 @@ def get_readme() -> str:
packages=find_packages("src"),
package_dir={"": "src"},
package_data={"seals": ["py.typed"]},
install_requires=["gym", "numpy", "matplotlib"],
install_requires=["gym", "numpy"],
tests_require=TESTS_REQUIRE,
extras_require={
# recommended packages for development
Expand Down
146 changes: 82 additions & 64 deletions src/seals/base_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def state(self) -> State:
@state.setter
def state(self, state: State):
"""Set current state."""
if self._cur_state is not None and self._cur_state not in self.state_space:
if state not in self.state_space:
raise ValueError(f"{state} not in {self.state_space}")
self._cur_state = state

Expand Down Expand Up @@ -130,8 +130,7 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
old_state = self.state
self.state = self.transition(self.state, action)
obs = self.obs_from_state(self.state)
if obs not in self.observation_space:
raise ValueError(f"{obs} not in {self.observation_space}")
assert obs in self.observation_space
reward = self.reward(old_state, action, self.state)
self._n_actions_taken += 1
done = self.terminal(self.state, self.n_actions_taken)
Expand Down Expand Up @@ -205,13 +204,13 @@ class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]):

transition_matrix: np.ndarray
reward_matrix: np.ndarray
observation_matrix: np.ndarray

state_space: spaces.Discrete

def __init__(
self,
*,
transition_matrix: np.ndarray,
observation_matrix: np.ndarray,
reward_matrix: np.ndarray,
horizon: float = np.inf,
initial_state_dist: Optional[np.ndarray] = None,
Expand All @@ -221,8 +220,6 @@ def __init__(
Args:
transition_matrix: 3-D array with transition probabilities for a
given state-action pair, of shape `(n_states,n_actions,n_states)`.
observation_matrix: 2-D array with observation probabilities for a
given state, of shape `(n_states,n_observations)`.
reward_matrix: 1-D, 2-D or 3-D array corresponding to rewards to a
given `(state, action, next_state)` triple. A 2-D array assumes
the `next_state` is not used in the reward, and a 1-D array
Expand All @@ -239,84 +236,62 @@ def __init__(
`initial_state_dist` have shapes different to specified above.
"""
# The following matrices should conform to the shapes below:
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved
# transition matrix: n_states x n_actions x n_states
# reward matrix: n_states x n_actions x n_states
# OR n_states x n_actions
# OR n_states
# observation matrix: n_states x n_observations
# initial state dist: n_states
# we want to make sure that the shapes are correct

if transition_matrix.shape[0] != transition_matrix.shape[2]:
# transition matrix: n_states x n_actions x n_states
n_states = transition_matrix.shape[0]
if n_states != transition_matrix.shape[2]:
raise ValueError(
"Malformed transition_matrix:\n"
f"transition_matrix.shape: {transition_matrix.shape}\n"
f"{transition_matrix.shape[0]} != {transition_matrix.shape[2]}",
f"{n_states} != {transition_matrix.shape[2]}",
)

# reward matrix: n_states x n_actions x n_states
# OR n_states x n_actions
# OR n_states
if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
raise ValueError(
"transition_matrix and reward_matrix are not compatible:\n"
f"transition_matrix.shape: {transition_matrix.shape}\n"
f"reward_matrix.shape: {reward_matrix.shape}",
)

if observation_matrix.shape[0] != transition_matrix.shape[0]:
raise ValueError(
"transition_matrix and observation_matrix are not compatible:\n"
f"transition_matrix.shape[0]: {transition_matrix.shape[0]}\n"
f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
)

# initial state dist: n_states
if initial_state_dist is None:
initial_state_dist = util.one_hot_encoding(0, transition_matrix.shape[0])
initial_state_dist = util.one_hot_encoding(0, n_states)
if initial_state_dist.ndim != 1:
raise ValueError(
"initial_state_dist has multiple dimensions:\n"
f"{initial_state_dist.ndim} != 1",
)
if initial_state_dist.shape[0] != transition_matrix.shape[0]:
if initial_state_dist.shape[0] != n_states:
raise ValueError(
"transition_matrix and initial_state_dist are not compatible:\n"
f"number of states = {transition_matrix.shape[0]}\n"
f"number of states = {n_states}\n"
f"len(initial_state_dist) = {len(initial_state_dist)}",
)

self.transition_matrix = transition_matrix
self.reward_matrix = reward_matrix
self.observation_matrix = observation_matrix
self._feature_matrix = None
self.horizon = horizon
self.initial_state_dist = initial_state_dist

super().__init__(
state_space=self._construct_state_space(self.state_dim),
action_space=self._construct_action_space(self.action_dim),
observation_space=self._construct_obs_space(self.obs_dim, self.obs_dtype),
state_space=self._construct_state_space(),
action_space=self._construct_action_space(),
observation_space=self._construct_observation_space(),
)

@staticmethod
def _construct_state_space(n_states: int) -> gym.Space:
return spaces.Discrete(n_states)
def _construct_state_space(self) -> gym.Space:
return spaces.Discrete(self.state_dim)

@staticmethod
def _construct_action_space(n_actions: int) -> gym.Space:
return spaces.Discrete(n_actions)
def _construct_action_space(self) -> gym.Space:
return spaces.Discrete(self.action_dim)

@staticmethod
def _construct_obs_space(obs_dim, obs_dtype) -> gym.Space:
try:
dtype_iinfo = np.iinfo(obs_dtype)
min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
except ValueError:
min_val = -np.inf
max_val = np.inf
return spaces.Box(
low=min_val,
high=max_val,
shape=(obs_dim,),
dtype=obs_dtype,
)
@abc.abstractmethod
def _construct_observation_space(self) -> gym.Space:
pass # pragma: no cover

def initial_state(self) -> int:
"""Samples from the initial state distribution."""
Expand Down Expand Up @@ -346,8 +321,6 @@ def feature_matrix(self):
"""Matrix mapping states to feature vectors."""
# Construct lazily to save memory in algorithms that don't need features.
if self._feature_matrix is None:
# TODO(juan) Space() does not have an `n` attribute (?).
# Are we hinting the wrong type?
n_states = self.state_space.n
self._feature_matrix = np.eye(n_states)
return self._feature_matrix
Expand All @@ -362,16 +335,6 @@ def action_dim(self) -> int:
"""Number of action vectors (int)."""
return self.transition_matrix.shape[1]

@property
def obs_dim(self) -> int:
"""Size of observation vectors for this MDP."""
return self.observation_matrix.shape[1]

@property
def obs_dtype(self) -> int:
"""Data type of observation vectors (e.g. np.float32)."""
return self.observation_matrix.dtype


class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]):
"""Tabular model POMDP.
Expand All @@ -385,6 +348,50 @@ class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]):
a vector with self.obs_dim entries.
"""

observation_matrix: np.ndarray

def __init__(
self,
*,
transition_matrix: np.ndarray,
observation_matrix: np.ndarray,
reward_matrix: np.ndarray,
horizon: float = np.inf,
initial_state_dist: Optional[np.ndarray] = None,
):
"""Initializes a tabular model POMDP."""
self.observation_matrix = observation_matrix
super().__init__(
transition_matrix=transition_matrix,
reward_matrix=reward_matrix,
horizon=horizon,
initial_state_dist=initial_state_dist,
)

# observation matrix: n_states x n_observations
if observation_matrix.shape[0] != self.state_dim:
raise ValueError(
"transition_matrix and observation_matrix are not compatible:\n"
f"transition_matrix.shape[0]: {self.state_dim}\n"
f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
)

def _construct_observation_space(self) -> gym.Space:
min_val: float
max_val: float
try:
dtype_iinfo = np.iinfo(self.obs_dtype)
min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
except ValueError:
min_val = -np.inf
max_val = np.inf
return spaces.Box(
low=min_val,
high=max_val,
shape=(self.obs_dim,),
dtype=self.obs_dtype,
)

def obs_from_state(self, state: int) -> np.ndarray:
AdamGleave marked this conversation as resolved.
Show resolved Hide resolved
"""Computes observation from state."""
# Copy so it can't be mutated in-place (updates will be reflected in
Expand All @@ -393,6 +400,16 @@ def obs_from_state(self, state: int) -> np.ndarray:
assert obs.ndim == 1, obs.shape
return obs

@property
def obs_dim(self) -> int:
"""Size of observation vectors for this MDP."""
return self.observation_matrix.shape[1]

@property
def obs_dtype(self) -> int:
"""Data type of observation vectors (e.g. np.float32)."""
return self.observation_matrix.dtype


class TabularModelMDP(BaseTabularModelPOMDP[int]):
"""Tabular model MDP.
Expand Down Expand Up @@ -424,10 +441,11 @@ def __init__(
reward_matrix=reward_matrix,
horizon=horizon,
initial_state_dist=initial_state_dist,
observation_matrix=np.eye(transition_matrix.shape[0]),
)
self._observation_space = self._state_space

def obs_from_state(self, state: int) -> int:
"""Identity since observation == state in an MDP."""
return state

def _construct_observation_space(self) -> gym.Space:
return self._construct_state_space()
39 changes: 39 additions & 0 deletions src/seals/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,42 @@
entry_point="seals.diagnostics.sort:SortEnv",
max_episode_steps=6,
)


def register_cliff_world(suffix, kwargs):
"""Register a CliffWorld with the given suffix and keyword arguments."""
gym.register(
f"seals/CliffWorld{suffix}-v0",
entry_point="seals.diagnostics.cliff_world:CliffWorldEnv",
kwargs=kwargs,
)


for width, height, horizon in [(7, 4, 9), (15, 6, 18), (100, 20, 110)]:
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved
for use_xy in [False, True]:
use_xy_str = "XY" if use_xy else ""
register_cliff_world(
f"{width}x{height}{use_xy_str}",
kwargs={
"width": width,
"height": height,
"use_xy_obs": use_xy,
"horizon": horizon,
},
)

# These parameter choices are somewhat arbitrary.
# We anticipate most users will want to construct RandomTransitionEnv directly.
gym.register(
"seals/Random-v0",
entry_point="seals.diagnostics.random_trans:RandomTransitionEnv",
kwargs={
"n_states": 16,
"n_actions": 3,
"branch_factor": 2,
"horizon": 20,
"random_obs": True,
"obs_dim": 5,
"generator_seed": 42,
},
)
Loading