Skip to content

Commit

Permalink
Merge pull request #54 from automl/context_mask_#53
Browse files Browse the repository at this point in the history
Context mask #53
  • Loading branch information
sebidoe authored Jul 1, 2022
2 parents 49bb50f + bd84c25 commit eada47e
Show file tree
Hide file tree
Showing 19 changed files with 80 additions and 4 deletions.
4 changes: 3 additions & 1 deletion carl/envs/box2d/carl_bipedal_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -119,7 +120,8 @@ def __init__(
state_context_features=state_context_features,
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,

)
self.whitelist_gaussian_noise = list(
Expand Down
4 changes: 3 additions & 1 deletion carl/envs/box2d/carl_lunarlander.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
max_episode_length: int = 1000,
high_gameover_penalty: bool = False,
dict_observation_space: bool = False,
Expand Down Expand Up @@ -147,7 +148,8 @@ def __init__(
max_episode_length=max_episode_length,
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/box2d/carl_vehicle_racing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = [
k for k in DEFAULT_CONTEXT.keys() if k not in CATEGORICAL_CONTEXT_FEATURES
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_grasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_halfcheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/brax/carl_ur5e.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
13 changes: 12 additions & 1 deletion carl/envs/carl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class CARLEnv(Wrapper):
If the context is visible to the agent (hide_context=False), the context features are appended to the state.
state_context_features specifies which of the context features are appended to the state. The default is
appending all context features.
context_mask: Optional[List[str]]
Names of the context features to be ignored when appending context features to the state.
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]]
Context selector (object of) class, e.g., can be RoundRobinSelector (default) or RandomSelector.
Should subclass AbstractSelector.
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = None,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand All @@ -104,6 +107,7 @@ def __init__(
self._contexts: Optional[Dict[Any, Dict[Any, Any]]] = None # init for property
self.default_context = default_context
self.contexts = contexts
self.context_mask = context_mask
self.hide_context = hide_context
self.dict_observation_space = dict_observation_space
self.cutoff = max_episode_length
Expand Down Expand Up @@ -153,7 +157,14 @@ def __init__(
json.dump(data, file, indent="\t")
else:
state_context_features = []
self.state_context_features = state_context_features
else:
state_context_features = list(self.contexts[list(self.contexts.keys())[0]].keys())
self.state_context_features: List[str] = state_context_features
# state_context_features contains the names of the context features that should be appended to the state.
# However, if context_mask is set, we update state_context_features so that the context features
# listed in context_mask are no longer appended to the state.
if self.context_mask:
self.state_context_features = [s for s in self.state_context_features if s not in self.context_mask]

self.step_counter = 0 # type: int # increased in/after step
self.total_timestep_counter = 0 # type: int
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/classic_control/carl_acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand All @@ -121,6 +122,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/classic_control/carl_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand All @@ -90,6 +91,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/classic_control/carl_mountaincar.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
max_episode_length: int = 200, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -110,6 +111,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/classic_control/carl_mountaincarcontinuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
max_episode_length: int = 999, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/classic_control/carl_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
max_episode_length: int = 200, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
DEFAULT_CONTEXT.keys()
Expand Down
2 changes: 1 addition & 1 deletion carl/envs/dmc/carl_dmcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(
environment_kwargs={"flat_observation": True}
)
env = MujocoToGymWrapper(env)
self.context_mask = context_mask
super().__init__(
env=env,
contexts=contexts,
Expand All @@ -58,6 +57,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
# TODO check gaussian noise on context features
self.whitelist_gaussian_noise = list(
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/mario/carl_mario.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
scale_context_features: str = "no",
default_context: Optional[Dict] = DEFAULT_CONTEXT,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
Expand All @@ -44,6 +45,7 @@ def __init__(
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.levels = []
self._update_context()
Expand Down
2 changes: 2 additions & 0 deletions carl/envs/rna/carl_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
default_context: Optional[Dict] = DEFAULT_CONTEXT,
context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None,
context_selector_kwargs: Optional[Dict] = None,
context_mask: Optional[List[str]] = None,
):
"""
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
default_context=default_context,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(DEFAULT_CONTEXT)

Expand Down
33 changes: 33 additions & 0 deletions test/test_CARLEnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ def test_dict_observation_space(self):
next_obs, reward, done, info = env.step(action=action)
env.close()

def test_state_context_feature_population(self):
    """Check that ``state_context_features`` is populated when None is passed.

    The environment is built with ``state_context_features=None`` and a
    visible context (``hide_context=False``); afterwards the attribute on
    the environment must not be None.
    """
    env_kwargs = dict(
        contexts={},
        hide_context=False,
        add_gaussian_noise_to_context=False,
        gaussian_noise_std_percentage=0.01,
        state_context_features=None,
        scale_context_features="no",
    )
    env = CARLPendulumEnv(**env_kwargs)
    populated_features = env.state_context_features
    self.assertIsNotNone(populated_features)


class TestEpisodeTermination(unittest.TestCase):
def test_episode_termination(self):
Expand Down Expand Up @@ -289,6 +302,26 @@ def test_context_feature_scaling_unknown_step(self):
with self.assertRaises(ValueError):
next_obs, reward, done, info = env.step(action=action)

def test_context_mask(self):
    """Check that masked context features are left out of the observation.

    The environment is created with a dict observation space and a
    ``context_mask`` naming two features. After ``reset`` the ``"context"``
    entry of the observation must be shorter than the default context by
    exactly the number of masked features, and none of the masked names may
    remain in ``env.state_context_features``.
    """
    context_mask = ["dt", "g"]
    env = CARLPendulumEnv(
        contexts={},
        hide_context=False,
        context_mask=context_mask,
        dict_observation_space=True,
        add_gaussian_noise_to_context=False,
        gaussian_noise_std_percentage=0.01,
        state_context_features=None,
        scale_context_features="no",
    )
    s = env.reset()
    s_c = s["context"]
    forbidden_in_context = [
        f for f in env.state_context_features if f in context_mask
    ]
    # Derive the expected size from the mask itself instead of a magic number.
    # NOTE(review): assumes every name in context_mask is a key of the default
    # context -- confirm against the Pendulum DEFAULT_CONTEXT.
    expected_len = len(env.default_context) - len(context_mask)
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(len(s_c), expected_len)
    self.assertEqual(forbidden_in_context, [])


class TestContextSelection(unittest.TestCase):
@staticmethod
Expand Down

0 comments on commit eada47e

Please sign in to comment.