-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Respect the order of keys in a Dict's observation space when flatteni…
…ng (#1748) * Respect the order of keys in a Dict's observation space when flattening Prior to this change, the order of the key/values in the observation was used instead of the order in the Dict's observation space. unflatten already uses the order specified by the Dict's observation space. * add tests for FlattenObservation
- Loading branch information
Showing
2 changed files
with
97 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Tests for the flatten observation wrapper.""" | ||
|
||
from collections import OrderedDict | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import gym | ||
from gym.spaces import Box, Dict, unflatten, flatten | ||
from gym.wrappers import FlattenObservation | ||
|
||
|
||
class FakeEnvironment(gym.Env): | ||
def __init__(self, observation_space): | ||
self.observation_space = observation_space | ||
|
||
def reset(self): | ||
self.observation = self.observation_space.sample() | ||
return self.observation | ||
|
||
|
||
OBSERVATION_SPACES = ( | ||
( | ||
Dict( | ||
OrderedDict( | ||
[ | ||
("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)), | ||
("key2", Box(shape=(), low=1, high=1, dtype=np.float32)), | ||
("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)), | ||
] | ||
) | ||
), | ||
True, | ||
), | ||
( | ||
Dict( | ||
OrderedDict( | ||
[ | ||
("key2", Box(shape=(), low=0, high=0, dtype=np.float32)), | ||
("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)), | ||
("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)), | ||
] | ||
) | ||
), | ||
True, | ||
), | ||
( | ||
Dict( | ||
{ | ||
"key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32), | ||
"key2": Box(shape=(), low=-1, high=1, dtype=np.float32), | ||
"key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), | ||
} | ||
), | ||
False, | ||
), | ||
) | ||
|
||
|
||
class TestFlattenEnvironment(object): | ||
@pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) | ||
def test_flattened_environment(self, observation_space, ordered_values): | ||
""" | ||
make sure that flattened observations occur in the order expected | ||
""" | ||
env = FakeEnvironment(observation_space=observation_space) | ||
wrapped_env = FlattenObservation(env) | ||
flattened = wrapped_env.reset() | ||
|
||
unflattened = unflatten(env.observation_space, flattened) | ||
original = env.observation | ||
|
||
self._check_observations(original, flattened, unflattened, ordered_values) | ||
|
||
@pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) | ||
def test_flatten_unflatten(self, observation_space, ordered_values): | ||
""" | ||
test flatten and unflatten functions directly | ||
""" | ||
original = observation_space.sample() | ||
|
||
flattened = flatten(observation_space, original) | ||
unflattened = unflatten(observation_space, flattened) | ||
|
||
self._check_observations(original, flattened, unflattened, ordered_values) | ||
|
||
def _check_observations(self, original, flattened, unflattened, ordered_values): | ||
# make sure that unflatten(flatten(original)) == original | ||
assert set(unflattened.keys()) == set(original.keys()) | ||
for k, v in original.items(): | ||
np.testing.assert_allclose(unflattened[k], v) | ||
|
||
if ordered_values: | ||
# make sure that the values were flattened in the order they appeared in the | ||
# OrderedDict | ||
np.testing.assert_allclose(sorted(flattened), flattened) |