-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add score masking to seven atari environments #62
Changes from all commits
6956efd
a97839f
b2f7d96
31af01a
aef96e5
013d942
6cb74b4
15f8569
e60925e
f480835
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Miscellaneous utilities.""" | ||
|
||
from typing import Optional, Tuple | ||
from dataclasses import dataclass | ||
from typing import List, Optional, Sequence, Tuple, Union | ||
|
||
import gym | ||
import numpy as np | ||
|
@@ -23,6 +24,67 @@ def step(self, action): | |
return obs, rew, False, info | ||
|
||
|
||
@dataclass
class BoxRegion:
    """A rectangular region dataclass used by MaskScoreWrapper.

    Attributes:
        x: `(start, stop)` bounds of the region. MaskScoreWrapper uses them
            to slice the first axis of the observation.
        y: `(start, stop)` bounds of the region. MaskScoreWrapper uses them
            to slice the second axis of the observation.
    """

    x: Tuple[int, int]
    y: Tuple[int, int]


# A collection of regions to be masked out of an observation.
MaskedRegionSpecifier = List[BoxRegion]
|
||
|
||
class MaskScoreWrapper(gym.Wrapper):
    """Mask a list of box-shaped regions in the observation to hide reward info.

    Intended for environments whose observations are raw pixels (like Atari
    environments). Used to mask regions of the observation that include information
    that could be used to infer the reward, like the game score or enemy ship count.
    """

    def __init__(
        self,
        env: gym.Env,
        score_regions: MaskedRegionSpecifier,
        fill_value: Union[float, Sequence[float]] = 0,
    ):
        """Builds MaskScoreWrapper.

        Args:
            env: The environment to wrap.
            score_regions: A list of box-shaped regions to mask, each given as
                a `BoxRegion(x=(x0, x1), y=(y0, y1))`, where `x0 < x1` and
                `y0 < y1`. The `x` bounds slice the first axis of the
                observation and the `y` bounds slice the second axis.
            fill_value: The fill_value for the masked region. By default is black.
                Can support RGB colors by being a sequence of values [r, g, b].

        Raises:
            ValueError: If a score region does not conform to the spec.
        """
        super().__init__(env)
        # Cast once so that masking does not change the observation dtype.
        self.fill_value = np.array(fill_value, env.observation_space.dtype)

        # True marks entries that are kept; masked regions are set to False.
        self.mask = np.ones(env.observation_space.shape, dtype=bool)
        for r in score_regions:
            if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]:
                raise ValueError('Invalid region: "x" and "y" must be increasing.')
            self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = False

    def _mask_obs(self, obs: np.ndarray) -> np.ndarray:
        """Returns a copy of `obs` with masked entries replaced by fill_value."""
        return np.where(self.mask, obs, self.fill_value)

    def step(self, action):
        """Returns (obs, rew, done, info) with masked obs."""
        obs, rew, done, info = self.env.step(action)
        return self._mask_obs(obs), rew, done, info

    def reset(self, **kwargs):
        """Returns masked reset observation."""
        obs = self.env.reset(**kwargs)
        return self._mask_obs(obs)
|
||
|
||
class ObsCastWrapper(gym.Wrapper): | ||
"""Cast observations to specified dtype. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,21 @@ | |
|
||
import collections | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
|
||
from seals import util | ||
from seals import GYM_ATARI_ENV_SPECS, util | ||
|
||
|
||
def test_mask_score_wrapper_enforces_spec():
    """Test that MaskScoreWrapper enforces the spec."""
    env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
    expected_message = 'Invalid region: "x" and "y" must be increasing.'
    # First region has decreasing "y" bounds, second has decreasing "x" bounds.
    invalid_regions = (
        util.BoxRegion(x=(0, 1), y=(1, 0)),
        util.BoxRegion(x=(1, 0), y=(0, 1)),
    )
    for region in invalid_regions:
        with pytest.raises(ValueError, match=expected_message):
            util.MaskScoreWrapper(env, [region])
|
||
|
||
def test_sample_distribution(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for adding the informative error message :)