Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add score masking to seven atari environments #62

Merged
merged 10 commits into from
Nov 22, 2022
67 changes: 56 additions & 11 deletions src/seals/atari.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
"""Adaptation of Atari environments for specification learning algorithms."""

from typing import Iterable
from typing import Dict, Iterable, List, Optional, Tuple

import gym

from seals.util import AutoResetWrapper, get_gym_max_episode_steps
from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps

SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You use List[Dict[str, Tuple[int, int]]] in three places in your code -- consider defining it as a type? Like:

MaskedRegionSpecifier = List[Dict[str, Tuple[int, int]]]

I'd also consider using a named tuple instead of dict to enforce that x and y are both present.

"BeamRider": [
dict(x=(5, 20), y=(45, 120)),
dict(x=(28, 40), y=(15, 40)),
],
"Breakout": [dict(x=(0, 16), y=(35, 80))],
"Enduro": [
dict(x=(163, 173), y=(55, 110)),
dict(x=(177, 188), y=(68, 107)),
],
"Pong": [dict(x=(0, 24), y=(0, 160))],
"Qbert": [dict(x=(6, 15), y=(33, 71))],
"Seaquest": [dict(x=(7, 19), y=(80, 110))],
"SpaceInvaders": [dict(x=(10, 20), y=(0, 160))],
}

def fixed_length_atari(atari_env_id: str) -> gym.Env:
    """Build the fixed-length variant of the Atari environment `atari_env_id`.

    The environment is created with `gym.make` and then wrapped in
    `AutoResetWrapper`.
    """
    base_env = gym.make(atari_env_id)
    return AutoResetWrapper(base_env)

def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, Tuple[int, int]]]]:
    """Look up the score regions registered for an Atari environment ID.

    The game name is the part of the ID after the last "/" and before the
    first "-", with any "NoFrameskip" substring removed.

    Returns:
        The matching entry of `SCORE_REGIONS`, or `None` if the game has none.
    """
    unqualified_id = atari_env_id.rsplit("/", 1)[-1]
    game_name, _, _ = unqualified_id.partition("-")
    return SCORE_REGIONS.get(game_name.replace("NoFrameskip", ""))


def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env:
"""Fixed-length, optionally masked-score variant of a given Atari environment."""
env = AutoResetWrapper(gym.make(atari_env_id))

if masked:
score_region = _get_score_region(atari_env_id)
if score_region is None:
raise ValueError(
"Requested environment does not yet support masking. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for adding the informative error message :)

+ "See https://github.com/HumanCompatibleAI/seals/issues/61.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The + is unnecessary (https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation). It actually introduces runtime overhead, even though this is largely irrelevant and it's more of a standard style choice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

)
env = MaskScoreWrapper(env, score_region)

return env


def _not_ram_or_det(env_id: str) -> bool:
Expand Down Expand Up @@ -37,20 +70,32 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
)


def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str:
def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str:
"""Makes a Gym ID for an Atari environment in the seals namespace."""
slash_separated = gym_spec.id.split("/")
return "seals/" + slash_separated[-1]
name = "seals/" + slash_separated[-1]

if not masked:
last_hyphen_idx = name.rfind("-")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we confident all environments will have a {name}-v{num} format? It's been the case everywhere that I've seen, but this would preclude us from registering environments without this format, and that's probably at least worth documenting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, in the _supported_atari_env method, we already only support an Atari environment if it ends with "-v4" or "-v5". So I think this is ok for now.

name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:]
return name


def register_atari_envs(
gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:
"""Register wrapped gym Atari environments."""
"""Register masked and unmasked wrapped gym Atari environments."""
for gym_spec in gym_atari_env_specs:
gym.register(
id=_seals_name(gym_spec),
entry_point="seals.atari:fixed_length_atari",
id=_seals_name(gym_spec, masked=False),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=False),
)
if _get_score_region(gym_spec.id) is not None:
AdamGleave marked this conversation as resolved.
Show resolved Hide resolved
gym.register(
id=_seals_name(gym_spec, masked=True),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=True),
)
48 changes: 47 additions & 1 deletion src/seals/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Miscellaneous utilities."""

from typing import Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union

import gym
import numpy as np
Expand All @@ -23,6 +23,52 @@ def step(self, action):
return obs, rew, False, info


class MaskScoreWrapper(gym.Wrapper):
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved
"""Mask a list of box-shaped regions in the observation to hide reward info.

Intended for environments whose observations are raw pixels (like atari
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Intended for environments whose observations are raw pixels (like atari
Intended for environments whose observations are raw pixels (like Atari

environments). Used to mask regions of the observation that include information
that could be used to infer the reward, like the game score or enemy ship count.
"""

def __init__(
self,
env: gym.Env,
score_regions: List[Dict[str, Tuple[int, int]]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If you did define a type alias this file would be the natural place to do it.)

fill_value: Union[float, Sequence[float]] = 0,
):
"""Builds MaskScoreWrapper.

Args:
env: The environment to wrap.
score_regions: A list of box-shaped regions to mask, each denoted by
a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1`
and `y0 < y1`.
fill_value: The fill_value for the masked region. By default is black.
Can support RGB colors by being a sequence of values [r, g, b].
"""
super().__init__(env)
self.fill_value = np.array(fill_value, env.observation_space.dtype)
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved

self.mask = np.ones(env.observation_space.shape, dtype=bool)
for r in score_regions:
assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a public method (users could create their own wrapper beyond our internal usage for seals-defined environments) you should raise a ValueError instead (and add a corresponding test). Passing in the wrong values is very much a possibility.

assert is to be used when something should always behave in a certain way by virtue of the purported logic of the program. This allows catching logical bugs and reassuring code checkers. This would be fine if only our own internally defined masks could ever be used. However, when something is part of the public API and therefore contingent on user input, we cannot really assert that a user won't pass the wrong value. See https://stackoverflow.com/questions/17530627/python-assertion-style#:~:text=The%20assert%20statement%20should%20only,user%20input%20or%20the%20environment. and https://wiki.python.org/moin/UsingAssertionsEffectively

What you can do, however, is have tests that assert that our internally defined masks verify this. That, plus a with raises test on the MaskScoreWrapper API should be enough to thoroughly test this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation!

self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0

def _mask_obs(self, obs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for adding this. Code looks cleaner now IMO.

return np.where(self.mask, obs, self.fill_value)

def step(self, action):
    """Step the wrapped env and return (obs, rew, done, info) with obs masked."""
    transition = self.env.step(action)
    raw_obs, reward, done, info = transition
    masked_obs = self._mask_obs(raw_obs)
    return masked_obs, reward, done, info

def reset(self, **kwargs):
    """Reset the wrapped env (forwarding `kwargs`) and return the masked obs."""
    return self._mask_obs(self.env.reset(**kwargs))


class ObsCastWrapper(gym.Wrapper):
"""Cast observations to specified dtype.

Expand Down
33 changes: 27 additions & 6 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import seals # noqa: F401 required for env registration
from seals.atari import _seals_name
from seals.atari import _get_score_region, _seals_name
from seals.testing import envs

ENV_NAMES: List[str] = [
Expand All @@ -26,7 +26,11 @@
]

ATARI_ENVS: List[str] = [
_seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS
_seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS
stewy33 marked this conversation as resolved.
Show resolved Hide resolved
] + [
_seals_name(gym_spec, masked=True)
for gym_spec in seals.GYM_ATARI_ENV_SPECS
if _get_score_region(gym_spec.id) is not None
]

# Subset of ATARI_ENVS whose ID ends with the "-v5" version suffix.
ATARI_V5_ENVS: List[str] = [name for name in ATARI_ENVS if name.endswith("-v5")]
Expand All @@ -46,14 +50,31 @@ def test_some_atari_envs():


def test_atari_space_invaders():
"""Tests if there's an Atari environment called space invaders."""
space_invader_environments = list(
"""Tests for masked and unmasked Atari space invaders environments."""
masked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name,
lambda name: "SpaceInvaders" in name and "Unmasked" not in name,
ATARI_ENVS,
),
)
assert len(space_invader_environments) > 0
assert len(masked_space_invader_environments) > 0

unmasked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name and "Unmasked" in name,
ATARI_ENVS,
),
)
assert len(unmasked_space_invader_environments) > 0


def test_atari_unmasked_env_naming():
    """Tests that all unmasked Atari envs have the appropriate name qualifier.

    An env is noncompliant when it has no score region (so it cannot be
    masked) and its name nevertheless lacks the "Unmasked" qualifier.
    """
    # Collect the offending names themselves rather than one boolean per env:
    # a list of booleans is as long as ATARI_ENVS, so its length can never
    # signal compliance.
    noncompliant_envs = [
        name
        for name in ATARI_ENVS
        if _get_score_region(name) is None and "Unmasked" not in name
    ]
    assert len(noncompliant_envs) == 0, noncompliant_envs


@pytest.mark.parametrize("env_name", ENV_NAMES)
Expand Down