diff --git a/carl/envs/box2d/carl_bipedal_walker.py b/carl/envs/box2d/carl_bipedal_walker.py index a28b51c7..e84ff11d 100644 --- a/carl/envs/box2d/carl_bipedal_walker.py +++ b/carl/envs/box2d/carl_bipedal_walker.py @@ -89,6 +89,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -119,7 +120,8 @@ def __init__( state_context_features=state_context_features, dict_observation_space=dict_observation_space, context_selector=context_selector, - context_selector_kwargs=context_selector_kwargs + context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( diff --git a/carl/envs/box2d/carl_lunarlander.py b/carl/envs/box2d/carl_lunarlander.py index b67f99ea..996ede5a 100644 --- a/carl/envs/box2d/carl_lunarlander.py +++ b/carl/envs/box2d/carl_lunarlander.py @@ -113,6 +113,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, max_episode_length: int = 1000, high_gameover_penalty: bool = False, dict_observation_space: bool = False, @@ -147,7 +148,8 @@ def __init__( max_episode_length=max_episode_length, dict_observation_space=dict_observation_space, context_selector=context_selector, - context_selector_kwargs=context_selector_kwargs + context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/box2d/carl_vehicle_racing.py b/carl/envs/box2d/carl_vehicle_racing.py index f64b49d3..f5187477 100644 --- a/carl/envs/box2d/carl_vehicle_racing.py +++ b/carl/envs/box2d/carl_vehicle_racing.py @@ -194,6 +194,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -230,6 +231,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = [ k for k in DEFAULT_CONTEXT.keys() if k not in CATEGORICAL_CONTEXT_FEATURES diff --git a/carl/envs/brax/carl_ant.py b/carl/envs/brax/carl_ant.py index a817b083..b0f27b34 100644 --- a/carl/envs/brax/carl_ant.py +++ b/carl/envs/brax/carl_ant.py @@ -49,6 +49,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -78,6 +79,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/brax/carl_fetch.py b/carl/envs/brax/carl_fetch.py index 0251eda4..8d800a91 100644 --- a/carl/envs/brax/carl_fetch.py +++ b/carl/envs/brax/carl_fetch.py @@ -53,6 +53,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -81,6 +82,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/brax/carl_grasp.py b/carl/envs/brax/carl_grasp.py index f2460311..04f40f02 100644 --- a/carl/envs/brax/carl_grasp.py +++ b/carl/envs/brax/carl_grasp.py @@ -53,6 +53,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -81,6 +82,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/brax/carl_halfcheetah.py b/carl/envs/brax/carl_halfcheetah.py index 65f432df..98bfda9b 100644 --- a/carl/envs/brax/carl_halfcheetah.py +++ b/carl/envs/brax/carl_halfcheetah.py @@ -47,6 +47,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -75,6 +76,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/brax/carl_humanoid.py b/carl/envs/brax/carl_humanoid.py index e8717b60..02e6fd04 100644 --- a/carl/envs/brax/carl_humanoid.py +++ b/carl/envs/brax/carl_humanoid.py @@ -48,6 +48,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -76,6 +77,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/brax/carl_ur5e.py b/carl/envs/brax/carl_ur5e.py index 53cdf766..4653bb63 100644 --- a/carl/envs/brax/carl_ur5e.py +++ b/carl/envs/brax/carl_ur5e.py @@ -53,6 +53,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -81,6 +82,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/carl_env.py b/carl/envs/carl_env.py index dd57c941..be7e4f90 100644 --- a/carl/envs/carl_env.py +++ b/carl/envs/carl_env.py @@ -63,6 +63,8 @@ class CARLEnv(Wrapper): If the context is visible to the agent (hide_context=False), the context features are appended to the state. state_context_features specifies which of the context features are appended to the state. The default is appending all context features. + context_mask: Optional[List[str]] + Name of context features to be ignored when appending context features to the state. context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] Context selector (object of) class, e.g., can be RoundRobinSelector (default) or RandomSelector. Should subclass AbstractSelector. @@ -94,6 +96,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = None, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -104,6 +107,7 @@ def __init__( self._contexts: Optional[Dict[Any, Dict[Any, Any]]] = None # init for property self.default_context = default_context self.contexts = contexts + self.context_mask = context_mask self.hide_context = hide_context self.dict_observation_space = dict_observation_space self.cutoff = max_episode_length @@ -153,7 +157,14 @@ def __init__( json.dump(data, file, indent="\t") else: state_context_features = [] - self.state_context_features = state_context_features + else: + state_context_features = list(self.contexts[list(self.contexts.keys())[0]].keys()) + self.state_context_features: List[str] = state_context_features + # state_context_features contains the names of the context features that should be appended to the state + # However, if context_mask is set, we want to update staet_context_feature_names so that the context features + # in context_mask are not appended to the state anymore. + if self.context_mask: + self.state_context_features = [s for s in self.state_context_features if s not in self.context_mask] self.step_counter = 0 # type: int # increased in/after step self.total_timestep_counter = 0 # type: int diff --git a/carl/envs/classic_control/carl_acrobot.py b/carl/envs/classic_control/carl_acrobot.py index 89e29873..11b5924d 100644 --- a/carl/envs/classic_control/carl_acrobot.py +++ b/carl/envs/classic_control/carl_acrobot.py @@ -101,6 +101,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -121,6 +122,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/classic_control/carl_cartpole.py b/carl/envs/classic_control/carl_cartpole.py index 7cdcc724..68387756 100644 --- a/carl/envs/classic_control/carl_cartpole.py +++ b/carl/envs/classic_control/carl_cartpole.py @@ -70,6 +70,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -90,6 +91,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/classic_control/carl_mountaincar.py b/carl/envs/classic_control/carl_mountaincar.py index 74df161b..ef361183 100644 --- a/carl/envs/classic_control/carl_mountaincar.py +++ b/carl/envs/classic_control/carl_mountaincar.py @@ -80,6 +80,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, max_episode_length: int = 200, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -110,6 +111,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/classic_control/carl_mountaincarcontinuous.py b/carl/envs/classic_control/carl_mountaincarcontinuous.py index 43b90f7a..dbfd3e50 100644 --- a/carl/envs/classic_control/carl_mountaincarcontinuous.py +++ b/carl/envs/classic_control/carl_mountaincarcontinuous.py @@ -75,6 +75,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, max_episode_length: int = 999, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -105,6 +106,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/classic_control/carl_pendulum.py b/carl/envs/classic_control/carl_pendulum.py index d6a83524..6b0748aa 100644 --- a/carl/envs/classic_control/carl_pendulum.py +++ b/carl/envs/classic_control/carl_pendulum.py @@ -68,6 +68,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, max_episode_length: int = 200, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -100,6 +101,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list( DEFAULT_CONTEXT.keys() diff --git a/carl/envs/dmc/carl_dmcontrol.py b/carl/envs/dmc/carl_dmcontrol.py index 0f43961d..1e3bb0bd 100644 --- a/carl/envs/dmc/carl_dmcontrol.py +++ b/carl/envs/dmc/carl_dmcontrol.py @@ -43,7 +43,6 @@ def __init__( environment_kwargs={"flat_observation": True} ) env = MujocoToGymWrapper(env) - self.context_mask = context_mask super().__init__( env=env, contexts=contexts, @@ -58,6 +57,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) # TODO check gaussian noise on context features self.whitelist_gaussian_noise = list( diff --git a/carl/envs/mario/carl_mario.py b/carl/envs/mario/carl_mario.py index 56fc82dc..949a9d12 100644 --- a/carl/envs/mario/carl_mario.py +++ b/carl/envs/mario/carl_mario.py @@ -26,6 +26,7 @@ def __init__( scale_context_features: str = "no", default_context: Optional[Dict] = DEFAULT_CONTEXT, state_context_features: Optional[List[str]] = None, + context_mask: Optional[List[str]] = None, dict_observation_space: bool = False, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, @@ -44,6 +45,7 @@ def __init__( dict_observation_space=dict_observation_space, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.levels = [] self._update_context() diff --git a/carl/envs/rna/carl_rna.py b/carl/envs/rna/carl_rna.py index ad4f94cb..57c1791f 100644 --- a/carl/envs/rna/carl_rna.py +++ b/carl/envs/rna/carl_rna.py @@ -54,6 +54,7 @@ def __init__( default_context: Optional[Dict] = DEFAULT_CONTEXT, context_selector: Optional[Union[AbstractSelector, type(AbstractSelector)]] = None, context_selector_kwargs: Optional[Dict] = None, + context_mask: Optional[List[str]] = None, ): """ @@ -98,6 +99,7 @@ def __init__( default_context=default_context, context_selector=context_selector, context_selector_kwargs=context_selector_kwargs, + context_mask=context_mask, ) self.whitelist_gaussian_noise = list(DEFAULT_CONTEXT) diff --git a/test/test_CARLEnv.py b/test/test_CARLEnv.py index 155cbcd5..9c8268be 100644 --- a/test/test_CARLEnv.py +++ b/test/test_CARLEnv.py @@ -152,6 +152,19 @@ def test_dict_observation_space(self): next_obs, reward, done, info = env.step(action=action) env.close() + def test_state_context_feature_population(self): + env = ( # noqa: F841 local variable is assigned to but never used + CARLPendulumEnv( + contexts={}, + hide_context=False, + add_gaussian_noise_to_context=False, + gaussian_noise_std_percentage=0.01, + state_context_features=None, + scale_context_features="no", + ) + ) + self.assertIsNotNone(env.state_context_features) + class TestEpisodeTermination(unittest.TestCase): def test_episode_termination(self): @@ -289,6 +302,26 @@ def test_context_feature_scaling_unknown_step(self): with self.assertRaises(ValueError): next_obs, reward, done, info = env.step(action=action) + def test_context_mask(self): + context_mask = ["dt", "g"] + env = ( # noqa: F841 local variable is assigned to but never used + CARLPendulumEnv( + contexts={}, + hide_context=False, + context_mask=context_mask, + dict_observation_space=True, + add_gaussian_noise_to_context=False, + gaussian_noise_std_percentage=0.01, + state_context_features=None, + scale_context_features="no", + ) + ) + s = env.reset() + s_c = s["context"] + forbidden_in_context = [f for f in env.state_context_features if f in context_mask] + self.assertTrue(len(s_c) == len(list(env.default_context.keys())) - 2) + self.assertTrue(len(forbidden_in_context) == 0) + class TestContextSelection(unittest.TestCase): @staticmethod