From e07594e665574b0f03650ebaf4907c76121846c9 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sat, 28 Sep 2024 23:46:52 +0200 Subject: [PATCH] [RLlib] MultiAgentEnv API enhancements (related to defining obs-/action spaces for agents). (#47830) --- rllib/algorithms/algorithm_config.py | 209 ++++++++----- .../tests/test_algorithm_rl_module_restore.py | 28 +- .../algorithms/tests/test_worker_failures.py | 80 ++--- rllib/core/__init__.py | 2 + rllib/core/learner/learner_group.py | 12 +- .../core/learner/tests/test_learner_group.py | 25 +- rllib/core/rl_module/multi_rl_module.py | 45 ++- .../rl_module/tests/test_multi_rl_module.py | 61 ++-- rllib/env/base_env.py | 79 ----- rllib/env/env_runner.py | 6 +- rllib/env/multi_agent_env.py | 290 +++++------------- rllib/env/multi_agent_env_runner.py | 53 ++-- rllib/env/tests/test_multi_agent_env.py | 162 ++++------ rllib/env/tests/test_multi_agent_episode.py | 15 +- rllib/env/vector_env.py | 10 - rllib/env/wrappers/open_spiel.py | 4 - rllib/env/wrappers/pettingzoo_env.py | 17 - rllib/evaluation/tests/test_rollout_worker.py | 3 - .../connectors/prev_actions_prev_rewards.py | 7 +- .../envs/classes/debug_counter_env.py | 2 - rllib/examples/envs/classes/two_step_game.py | 2 - .../different_spaces_for_agents.py | 26 +- rllib/utils/pre_checks/env.py | 90 ++---- .../multi_agent_episode_buffer.py | 23 +- .../multi_agent_prioritized_episode_buffer.py | 15 +- 25 files changed, 473 insertions(+), 793 deletions(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 6ff95f612905..31f8940d1f50 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -982,12 +982,26 @@ def build_env_to_module_connector(self, env): f"pipeline)! Your function returned {val_}." ) + obs_space = getattr(env, "single_observation_space", env.observation_space) + if obs_space is None and self.is_multi_agent(): + obs_space = gym.spaces.Dict( + { + aid: env.get_observation_space(aid) + for aid in env.unwrapped.possible_agents + } + ) + act_space = getattr(env, "single_action_space", env.action_space) + if act_space is None and self.is_multi_agent(): + act_space = gym.spaces.Dict( + { + aid: env.get_action_space(aid) + for aid in env.unwrapped.possible_agents + } + ) pipeline = EnvToModulePipeline( + input_observation_space=obs_space, + input_action_space=act_space, connectors=custom_connectors, - input_observation_space=getattr( - env, "single_observation_space", env.observation_space - ), - input_action_space=getattr(env, "single_action_space", env.action_space), ) if self.add_default_connectors_to_env_to_module_pipeline: @@ -1048,12 +1062,26 @@ def build_module_to_env_connector(self, env): f"pipeline)! Your function returned {val_}." 
) + obs_space = getattr(env, "single_observation_space", env.observation_space) + if obs_space is None and self.is_multi_agent(): + obs_space = gym.spaces.Dict( + { + aid: env.get_observation_space(aid) + for aid in env.unwrapped.possible_agents + } + ) + act_space = getattr(env, "single_action_space", env.action_space) + if act_space is None and self.is_multi_agent(): + act_space = gym.spaces.Dict( + { + aid: env.get_action_space(aid) + for aid in env.unwrapped.possible_agents + } + ) pipeline = ModuleToEnvPipeline( + input_observation_space=obs_space, + input_action_space=act_space, connectors=custom_connectors, - input_observation_space=getattr( - env, "single_observation_space", env.observation_space - ), - input_action_space=getattr(env, "single_action_space", env.action_space), ) if self.add_default_connectors_to_module_to_env_pipeline: @@ -4916,47 +4944,54 @@ def get_multi_agent_setup( # Infer observation space. if policy_spec.observation_space is None: + env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env + # Module's space is provided -> Use it as-is. if spaces is not None and pid in spaces: obs_space = spaces[pid][0] - elif env_obs_space is not None: - env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env - # Multi-agent case AND different agents have different spaces: - # Need to reverse map spaces (for the different agents) to certain - # policy IDs. - if ( - isinstance(env_unwrapped, MultiAgentEnv) - and hasattr(env_unwrapped, "_obs_space_in_preferred_format") - and env_unwrapped._obs_space_in_preferred_format - ): - obs_space = None - mapping_fn = self.policy_mapping_fn - one_obs_space = next(iter(env_obs_space.values())) - # If all obs spaces are the same anyways, just use the first - # single-agent space. - if all(s == one_obs_space for s in env_obs_space.values()): - obs_space = one_obs_space - # Otherwise, we have to compare the ModuleID with all possible - # AgentIDs and find the agent ID that matches. - elif mapping_fn: - for aid in env_unwrapped.get_agent_ids(): - # Match: Assign spaces for this agentID to the PolicyID. - if mapping_fn(aid, None, worker=None) == pid: - # Make sure, different agents that map to the same - # policy don't have different spaces. - if ( - obs_space is not None - and env_obs_space[aid] != obs_space - ): - raise ValueError( - "Two agents in your environment map to the " - "same policyID (as per your `policy_mapping" - "_fn`), however, these agents also have " - "different observation spaces!" - ) - obs_space = env_obs_space[aid] - # Otherwise, just use env's obs space as-is. + # MultiAgentEnv -> Check, whether agents have different spaces. + elif isinstance(env_unwrapped, MultiAgentEnv): + obs_space = None + mapping_fn = self.policy_mapping_fn + aids = list( + env_unwrapped.possible_agents + if hasattr(env_unwrapped, "possible_agents") + and env_unwrapped.possible_agents + else env_unwrapped.get_agent_ids() + ) + if len(aids) == 0: + one_obs_space = env_unwrapped.observation_space else: - obs_space = env_obs_space + one_obs_space = env_unwrapped.get_observation_space(aids[0]) + # If all obs spaces are the same, just use the first space. + if all( + env_unwrapped.get_observation_space(aid) == one_obs_space + for aid in aids + ): + obs_space = one_obs_space + # Need to reverse-map spaces (for the different agents) to certain + # policy IDs. We have to compare the ModuleID with all possible + # AgentIDs and find the agent ID that matches. 
+ elif mapping_fn: + for aid in aids: + # Match: Assign spaces for this agentID to the PolicyID. + if mapping_fn(aid, None, worker=None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + obs_space is not None + and env_unwrapped.get_observation_space(aid) + != obs_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different observation spaces!" + ) + obs_space = env_unwrapped.get_observation_space(aid) + # Just use env's obs space as-is. + elif env_obs_space is not None: + obs_space = env_obs_space # Space given directly in config. elif self.observation_space: obs_space = self.observation_space @@ -4972,47 +5007,53 @@ def get_multi_agent_setup( # Infer action space. if policy_spec.action_space is None: + env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env + # Module's space is provided -> Use it as-is. if spaces is not None and pid in spaces: act_space = spaces[pid][1] - elif env_act_space is not None: - env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env - # Multi-agent case AND different agents have different spaces: - # Need to reverse map spaces (for the different agents) to certain - # policy IDs. - if ( - isinstance(env_unwrapped, MultiAgentEnv) - and hasattr(env_unwrapped, "_action_space_in_preferred_format") - and env_unwrapped._action_space_in_preferred_format - ): - act_space = None - mapping_fn = self.policy_mapping_fn - one_act_space = next(iter(env_act_space.values())) - # If all action spaces are the same anyways, just use the first - # single-agent space. - if all(s == one_act_space for s in env_act_space.values()): - act_space = one_act_space - # Otherwise, we have to compare the ModuleID with all possible - # AgentIDs and find the agent ID that matches. - elif mapping_fn: - for aid in env_unwrapped.get_agent_ids(): - # Match: Assign spaces for this AgentID to the PolicyID. - if mapping_fn(aid, None, worker=None) == pid: - # Make sure, different agents that map to the same - # policy don't have different spaces. - if ( - act_space is not None - and env_act_space[aid] != act_space - ): - raise ValueError( - "Two agents in your environment map to the " - "same policyID (as per your `policy_mapping" - "_fn`), however, these agents also have " - "different action spaces!" - ) - act_space = env_act_space[aid] - # Otherwise, just use env's action space as-is. + # MultiAgentEnv -> Check, whether agents have different spaces. + elif isinstance(env_unwrapped, MultiAgentEnv): + act_space = None + mapping_fn = self.policy_mapping_fn + aids = list( + env_unwrapped.possible_agents + if hasattr(env_unwrapped, "possible_agents") + and env_unwrapped.possible_agents + else env_unwrapped.get_agent_ids() + ) + if len(aids) == 0: + one_act_space = env_unwrapped.action_space else: - act_space = env_act_space + one_act_space = env_unwrapped.get_action_space(aids[0]) + # If all obs spaces are the same, just use the first space. + if all( + env_unwrapped.get_action_space(aid) == one_act_space + for aid in aids + ): + act_space = one_act_space + # Need to reverse-map spaces (for the different agents) to certain + # policy IDs. We have to compare the ModuleID with all possible + # AgentIDs and find the agent ID that matches. + elif mapping_fn: + for aid in aids: + # Match: Assign spaces for this AgentID to the PolicyID. 
+ if mapping_fn(aid, None, worker=None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + act_space is not None + and env_unwrapped.get_action_space(aid) != act_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different action spaces!" + ) + act_space = env_unwrapped.get_action_space(aid) + # Just use env's action space as-is. + elif env_act_space is not None: + act_space = env_act_space elif self.action_space: act_space = self.action_space else: diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py index 7e261ced6381..fb0b5d2f4fee 100644 --- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py +++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py @@ -72,8 +72,8 @@ def test_e2e_load_simple_multi_rl_module(self): for i in range(NUM_AGENTS): module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), # If we want to use this externally created module in the algorithm, # we need to provide the same config as the algorithm. model_config_dict=config.model_config @@ -115,8 +115,8 @@ def test_e2e_load_complex_multi_rl_module(self): for i in range(NUM_AGENTS): module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), # If we want to use this externally created module in the algorithm, # we need to provide the same config as the algorithm. model_config_dict=config.model_config @@ -131,8 +131,8 @@ def test_e2e_load_complex_multi_rl_module(self): # create a RLModule to load and override the "policy_1" module with module_to_swap_in = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), # Note, we need to pass in the default model config for the algorithm # to be able to use this module later. model_config_dict=config.model_config | {"fcnet_hiddens": [64]}, @@ -146,8 +146,8 @@ def test_e2e_load_complex_multi_rl_module(self): # and the module_to_swap_in_checkpoint module_specs["policy_1"] = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), model_config_dict={"fcnet_hiddens": [64]}, catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, @@ -258,8 +258,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): for i in range(num_agents): module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), # Note, we need to pass in the default model config for the # algorithm to be able to use this module later. 
model_config_dict=config.model_config @@ -274,8 +274,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): # create a RLModule to load and override the "policy_1" module with module_to_swap_in = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), # Note, we need to pass in the default model config for the algorithm # to be able to use this module later. model_config_dict=config.model_config | {"fcnet_hiddens": [64]}, @@ -289,8 +289,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): # and the module_to_swap_in_checkpoint module_specs["policy_1"] = RLModuleSpec( module_class=module_class, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), model_config_dict={"fcnet_hiddens": [64]}, catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, diff --git a/rllib/algorithms/tests/test_worker_failures.py b/rllib/algorithms/tests/test_worker_failures.py index 8e603694a158..2525ca307b80 100644 --- a/rllib/algorithms/tests/test_worker_failures.py +++ b/rllib/algorithms/tests/test_worker_failures.py @@ -8,6 +8,7 @@ from ray.util.state import list_actors from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.algorithms.impala import IMPALAConfig from ray.rllib.algorithms.sac.sac import SACConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations @@ -157,9 +158,6 @@ def step(self, action): return self.env.step(action) - def action_space_sample(self): - return self.env.action_space.sample() - class ForwardHealthCheckToEnvWorker(SingleAgentEnvRunner): """Configuring EnvRunner to error in specific condition is hard. @@ -410,36 +408,24 @@ def test_fatal_multi_agent(self): .multi_agent(policies={"p0"}, policy_mapping_fn=lambda *a, **k: "p0"), ) - # TODO (sven): Reinstate once IMPALA/APPO support EnvRunners. - # def test_async_samples(self): - # self._do_test_fault_ignore( - # IMPALAConfig() - # .api_stack( - # enable_rl_module_and_learner=True, - # enable_env_runners_and_connector_v2=True, - # ) - # .env_runners(env_runner_cls=ForwardHealthCheckToEnvWorker) - # .resources(num_gpus=0) - # ) - - def test_sync_replay(self): + def test_async_samples(self): self._do_test_failing_ignore( - SACConfig() + IMPALAConfig() .api_stack( enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) + .env_runners(env_runner_cls=ForwardHealthCheckToEnvWorker) + ) + + def test_sync_replay(self): + self._do_test_failing_ignore( + SACConfig() .environment( env_config={"action_space": gym.spaces.Box(0, 1, (2,), np.float32)} ) .env_runners(env_runner_cls=ForwardHealthCheckToEnvWorker) .reporting(min_sample_timesteps_per_iteration=1) - .training( - replay_buffer_config={"type": "EpisodeReplayBuffer"}, - # We need to set the base `lr` to `None` b/c SAC in the new stack - # has its own learning rates. 
- lr=None, - ) ) def test_multi_gpu(self): @@ -525,8 +511,8 @@ def test_eval_workers_parallel_to_training_multi_agent_failing_recover( ) .evaluation( evaluation_num_env_runners=1, - evaluation_parallel_to_training=True, - evaluation_duration="auto", + # evaluation_parallel_to_training=True, + # evaluation_duration="auto", ) .training(model={"fcnet_hiddens": [4]}) ) @@ -687,26 +673,26 @@ def test_modules_are_restored_on_recovered_worker(self): self.assertEqual(algo.eval_env_runner_group.num_healthy_remote_workers(), 1) self.assertEqual(algo.eval_env_runner_group.num_remote_worker_restarts(), 1) - # Let's verify that our custom module exists on both recovered workers. - # TODO (sven): Reinstate once EnvRunners moved to new get/set_state APIs (from - # get/set_weights). - # def has_test_module(w): - # return "test_module" in w.module + # Let's verify that our custom module exists on all recovered workers. + def has_test_module(w): + return "test_module" in w.module # Rollout worker has test module. - # self.assertTrue( - # all(algo.env_runner_group.foreach_worker( - # has_test_module, local_worker=False - # )) - # ) + self.assertTrue( + all( + algo.env_runner_group.foreach_worker( + has_test_module, local_env_runner=False + ) + ) + ) # Eval worker has test module. - # self.assertTrue( - # all( - # algo.eval_env_runner_group.foreach_worker( - # has_test_module, local_worker=False - # ) - # ) - # ) + self.assertTrue( + all( + algo.eval_env_runner_group.foreach_worker( + has_test_module, local_env_runner=False + ) + ) + ) def test_eval_workers_failing_recover(self): # Counter that will survive restarts. @@ -786,16 +772,6 @@ def test_worker_failing_recover_with_hanging_workers(self): # the execution of the algorithm b/c of a single heavily stalling worker. # Timeout data (batches or episodes) are discarded. SACConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .training( - replay_buffer_config={"type": "EpisodeReplayBuffer"}, - # We need to set the base `lr` to `None` b/c new stack SAC has its - # specific learning rates for actor, critic, and alpha. 
- lr=None, - ) .env_runners( env_runner_cls=ForwardHealthCheckToEnvWorker, num_env_runners=3, diff --git a/rllib/core/__init__.py b/rllib/core/__init__.py index 1744008e602e..67b1ced1ad28 100644 --- a/rllib/core/__init__.py +++ b/rllib/core/__init__.py @@ -12,6 +12,7 @@ COMPONENT_LEARNER_GROUP = "learner_group" COMPONENT_METRICS_LOGGER = "metrics_logger" COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector" +COMPONENT_MULTI_RL_MODULE_SPEC = "_multi_rl_module_spec" COMPONENT_OPTIMIZER = "optimizer" COMPONENT_RL_MODULE = "rl_module" @@ -25,6 +26,7 @@ "COMPONENT_LEARNER_GROUP", "COMPONENT_METRICS_LOGGER", "COMPONENT_MODULE_TO_ENV_CONNECTOR", + "COMPONENT_MULTI_RL_MODULE_SPEC", "COMPONENT_OPTIMIZER", "COMPONENT_RL_MODULE", "DEFAULT_AGENT_ID", diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index ea792b2c8c29..bc06dae36c87 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -18,7 +18,11 @@ import ray from ray import ObjectRef -from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE +from ray.rllib.core import ( + COMPONENT_LEARNER, + COMPONENT_MULTI_RL_MODULE_SPEC, + COMPONENT_RL_MODULE, +) from ray.rllib.core.learner.learner import Learner from ray.rllib.core.rl_module import validate_module_id from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec @@ -788,8 +792,10 @@ def get_weights( list(module_ids), ) ] - - return self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + state = self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + # Remove the MultiRLModuleSpec to just get the weights. + state.pop(COMPONENT_MULTI_RL_MODULE_SPEC, None) + return state def set_weights(self, weights) -> None: """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}). diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index 87a7343074fb..5e75d87a9e10 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -10,6 +10,7 @@ from ray.rllib.algorithms.ppo.tests.test_ppo_learner import FAKE_BATCH from ray.rllib.core import ( COMPONENT_LEARNER, + COMPONENT_MULTI_RL_MODULE_SPEC, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID, ) @@ -469,14 +470,12 @@ def test_save_to_path_and_restore_from_path(self): # Do a single update. learner_group.update_from_batch(batch.as_multi_agent()) + weights_after_update = learner_group.get_state( + components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + )[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + weights_after_update.pop(COMPONENT_MULTI_RL_MODULE_SPEC) # Weights after the update must be different from original ones. - check( - initial_weights, - learner_group.get_state( - components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE - )[COMPONENT_LEARNER][COMPONENT_RL_MODULE], - false=True, - ) + check(initial_weights, weights_after_update, false=True) # Checkpoint the learner state after 1 update for later comparison. learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name @@ -497,18 +496,18 @@ def test_save_to_path_and_restore_from_path(self): weights_after_2_updates_with_break = learner_group.get_state( components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE )[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + weights_after_2_updates_with_break.pop(COMPONENT_MULTI_RL_MODULE_SPEC) learner_group.shutdown() del learner_group # Construct a new learner group and load the initial state of the learner. 
learner_group = config.build_learner_group(env=env) learner_group.restore_from_path(initial_learner_checkpoint_dir) - check( - initial_weights, - learner_group.get_state( - components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE - )[COMPONENT_LEARNER][COMPONENT_RL_MODULE], - ) + weights_after_restore = learner_group.get_state( + components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + )[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + weights_after_restore.pop(COMPONENT_MULTI_RL_MODULE_SPEC) + check(initial_weights, weights_after_restore) # Perform 2 updates to get to the same state as the previous learners. learner_group.update_from_batch(batch.as_multi_agent()) results_2nd_without_break = learner_group.update_from_batch( diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 6ffd6b3345f6..a4b0deedce1e 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -17,9 +17,9 @@ ValuesView, ) +from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec - from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import force_list from ray.rllib.utils.annotations import ( @@ -297,11 +297,9 @@ def _forward_train( ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: """Runs the forward_train pass. - TODO(avnishn, kourosh): Review type hints for forward methods. - Args: batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). + individual modules' batches). Returns: The output of the forward_train pass the specified modules. @@ -314,11 +312,9 @@ def _forward_inference( ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: """Runs the forward_inference pass. - TODO(avnishn, kourosh): Review type hints for forward methods. - Args: batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). + individual modules' batches). Returns: The output of the forward_inference pass the specified modules. @@ -331,11 +327,9 @@ def _forward_exploration( ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: """Runs the forward_exploration pass. - TODO(avnishn, kourosh): Review type hints for forward methods. - Args: batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). + individual modules' batches). Returns: The output of the forward_exploration pass the specified modules. @@ -353,6 +347,17 @@ def get_state( ) -> StateDict: state = {} + # We store the current RLModuleSpec as well as it might have changed over time + # (modules added/removed from `self`). + if self._check_component( + COMPONENT_MULTI_RL_MODULE_SPEC, + components, + not_components, + ): + state[COMPONENT_MULTI_RL_MODULE_SPEC] = MultiRLModuleSpec.from_module( + self + ).to_dict() + for module_id, rl_module in self.get_checkpointable_components(): if self._check_component(module_id, components, not_components): state[module_id] = rl_module.get_state( @@ -376,7 +381,27 @@ def set_state(self, state: StateDict) -> None: Args: state: The state dict to set. """ + # Check the given MultiRLModuleSpec and - if there are changes in the individual + # sub-modules - apply these to this MultiRLModule. 
+ if COMPONENT_MULTI_RL_MODULE_SPEC in state: + multi_rl_module_spec = MultiRLModuleSpec.from_dict( + state[COMPONENT_MULTI_RL_MODULE_SPEC] + ) + # Go through all of our current modules and check, whether they are listed + # in the given MultiRLModuleSpec. If not, erase them from `self`. + for module_id, module in self._rl_modules.items(): + if module_id not in multi_rl_module_spec.module_specs: + self.remove_module(module_id, raise_err_if_not_found=True) + # Go through all the modules in the given MultiRLModuleSpec and if + # they are not present in `self`, add them. + for module_id, module_spec in multi_rl_module_spec.module_specs.items(): + if module_id not in self: + self.add_module(module_id, module_spec.build(), override=False) + + # Now, set the individual states for module_id, module_state in state.items(): + if module_id == COMPONENT_MULTI_RL_MODULE_SPEC: + continue if module_id in self: self._rl_modules[module_id].set_state(module_state) diff --git a/rllib/core/rl_module/tests/test_multi_rl_module.py b/rllib/core/rl_module/tests/test_multi_rl_module.py index 3ec8f4788247..30751d8a058e 100644 --- a/rllib/core/rl_module/tests/test_multi_rl_module.py +++ b/rllib/core/rl_module/tests/test_multi_rl_module.py @@ -1,7 +1,7 @@ import tempfile import unittest -from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC, DEFAULT_MODULE_ID from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModuleConfig from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule @@ -16,15 +16,15 @@ def test_from_config(self): env = env_class({"num_agents": 2}) module1 = RLModuleSpec( module_class=DiscreteBCTorchModule, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) module2 = RLModuleSpec( module_class=DiscreteBCTorchModule, - observation_space=env.observation_space[0], - action_space=env.action_space[0], + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) @@ -42,8 +42,8 @@ def test_as_multi_rl_module(self): multi_rl_module = DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ).as_multi_rl_module() @@ -63,15 +63,18 @@ def test_get_state_and_set_state(self): module = DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ).as_multi_rl_module() state = module.get_state() self.assertIsInstance(state, dict) - self.assertEqual(set(state.keys()), set(module.keys())) + self.assertEqual( + set(state.keys()) - {COMPONENT_MULTI_RL_MODULE_SPEC}, + set(module.keys()), + ) self.assertEqual( set(state[DEFAULT_MODULE_ID].keys()), set(module[DEFAULT_MODULE_ID].get_state().keys()), @@ -79,13 +82,13 @@ def test_get_state_and_set_state(self): module2 = DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ).as_multi_rl_module() state2 = module2.get_state() - check(state, state2, false=True) 
+ check(state[DEFAULT_MODULE_ID], state2[DEFAULT_MODULE_ID], false=True) module2.set_state(state) state2_after = module2.get_state() @@ -99,8 +102,8 @@ def test_add_remove_modules(self): env = env_class({"num_agents": 2}) module = DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ).as_multi_rl_module() @@ -109,8 +112,8 @@ def test_add_remove_modules(self): "test", DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ), @@ -126,8 +129,8 @@ def test_add_remove_modules(self): DEFAULT_MODULE_ID, DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ), @@ -138,8 +141,8 @@ def test_add_remove_modules(self): DEFAULT_MODULE_ID, DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ), @@ -152,8 +155,8 @@ def test_save_to_path_and_from_checkpoint(self): env = env_class({"num_agents": 2}) module = DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ).as_multi_rl_module() @@ -162,8 +165,8 @@ def test_save_to_path_and_from_checkpoint(self): "test", DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [32]}, ) ), @@ -172,8 +175,8 @@ def test_save_to_path_and_from_checkpoint(self): "test2", DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [128]}, ) ), @@ -203,8 +206,8 @@ def test_save_to_path_and_from_checkpoint(self): "test3", DiscreteBCTorchModule( config=RLModuleConfig( - env.observation_space[0], - env.action_space[0], + env.get_observation_space(0), + env.get_action_space(0), model_config_dict={"fcnet_hiddens": [120]}, ) ), diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index f6bd9a3f9836..c67c642e4763 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -282,36 +282,6 @@ def action_space(self) -> gym.Space: """ raise NotImplementedError - def action_space_sample(self, agent_id: list = None) -> MultiEnvDict: - """Returns a random action for each environment, and potentially each - agent in that environment. - - Args: - agent_id: List of agent ids to sample actions for. If None or empty - list, sample actions for all agents in the environment. - - Returns: - A random action for each environment. - """ - logger.warning("action_space_sample() has not been implemented") - del agent_id - return {} - - def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict: - """Returns a random observation for each environment, and potentially - each agent in that environment. - - Args: - agent_id: List of agent ids to sample actions for. If None or empty - list, sample actions for all agents in the environment. - - Returns: - A random action for each environment. 
- """ - logger.warning("observation_space_sample() has not been implemented") - del agent_id - return {} - def last( self, ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]: @@ -326,55 +296,6 @@ def last( logger.warning("last has not been implemented for this environment.") return {}, {}, {}, {}, {} - def observation_space_contains(self, x: MultiEnvDict) -> bool: - """Checks if the given observation is valid for each environment. - - Args: - x: Observations to check. - - Returns: - True if the observations are contained within their respective - spaces. False otherwise. - """ - return self._space_contains(self.observation_space, x) - - def action_space_contains(self, x: MultiEnvDict) -> bool: - """Checks if the given actions is valid for each environment. - - Args: - x: Actions to check. - - Returns: - True if the actions are contained within their respective - spaces. False otherwise. - """ - return self._space_contains(self.action_space, x) - - def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool: - """Check if the given space contains the observations of x. - - Args: - space: The space to if x's observations are contained in. - x: The observations to check. - - Returns: - True if the observations of x are contained in space. - """ - agents = set(self.get_agent_ids()) - for multi_agent_dict in x.values(): - for agent_id, obs in multi_agent_dict.items(): - # this is for the case where we have a single agent - # and we're checking a Vector env thats been converted to - # a BaseEnv - if agent_id == _DUMMY_AGENT_ID: - if not space.contains(obs): - return False - # for the MultiAgent env case - elif (agent_id not in agents) or (not space[agent_id].contains(obs)): - return False - - return True - # Fixed agent identifier when there is only the single agent in the env _DUMMY_AGENT_ID = "agent0" diff --git a/rllib/env/env_runner.py b/rllib/env/env_runner.py index 2d528777c0e7..af3f1a11cdac 100644 --- a/rllib/env/env_runner.py +++ b/rllib/env/env_runner.py @@ -103,11 +103,7 @@ def sample(self, **kwargs) -> Any: @abc.abstractmethod def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]: - """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space. - - The returned dict might also contain an extra key `__env__`, which maps to - a 2-tuple of the bare Env's observation- and action spaces. - """ + """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space.""" def stop(self) -> None: """Releases all resources used by this EnvRunner. diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index d3b41d49f906..40c3107fb042 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -11,6 +11,7 @@ override, PublicAPI, ) +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import ( AgentID, EnvCreator, @@ -31,39 +32,48 @@ class MultiAgentEnv(gym.Env): """An environment that hosts multiple independent agents. - Agents are identified by (string) agent ids. Note that these "agents" here - are not to be confused with RLlib Algorithms, which are also sometimes - referred to as "agents" or "RL agents". - - The preferred format for action- and observation space is a mapping from agent - ids to their individual spaces. If that is not provided, the respective methods' - observation_space_contains(), action_space_contains(), - action_space_sample() and observation_space_sample() have to be overwritten. + Agents are identified by AgentIDs (string). 
""" + # Optional mappings from AgentID to individual agents' spaces. + # Set this to an "exhaustive" dictionary, mapping all possible AgentIDs to + # individual agents' spaces. Alternatively, override + # `get_observation_space(agent_id=...)` and `get_action_space(agent_id=...)`, which + # is the API that RLlib uses to get individual spaces and whose default + # implementation is to simply look up `agent_id` in these dicts. + observation_spaces: Optional[Dict[AgentID, gym.Space]] = None + action_spaces: Optional[Dict[AgentID, gym.Space]] = None + + # All agents currently active in the environment. This attribute may change during + # the lifetime of the env or even during an individual episode. + agents: List[AgentID] = [] + # All agents that may appear in the environment, ever. + # This attribute should not be changed during the lifetime of this env. + possible_agents: List[AgentID] = [] + + # @OldAPIStack + observation_space: Optional[gym.Space] = None + action_space: Optional[gym.Space] = None + def __init__(self): super().__init__() - if not hasattr(self, "observation_space"): - self.observation_space = None - if not hasattr(self, "action_space"): - self.action_space = None + # @OldAPIStack if not hasattr(self, "_agent_ids"): self._agent_ids = set() - # Do the action and observation spaces map from agent ids to spaces - # for the individual agents? - if not hasattr(self, "_action_space_in_preferred_format"): - self._action_space_in_preferred_format = None - if not hasattr(self, "_obs_space_in_preferred_format"): - self._obs_space_in_preferred_format = None + # If these important attributes are not set, try to infer them. + if not self.agents: + self.agents = list(self._agent_ids) + if not self.possible_agents: + self.possible_agents = self.agents.copy() def reset( self, *, seed: Optional[int] = None, options: Optional[dict] = None, - ) -> Tuple[MultiAgentDict, MultiAgentDict]: + ) -> Tuple[MultiAgentDict, MultiAgentDict]: # type: ignore """Resets the env and returns observations from ready agents. Args: @@ -145,165 +155,45 @@ def step( """ raise NotImplementedError - def observation_space_contains(self, x: MultiAgentDict) -> bool: - """Checks if the observation space contains the given key. - - Args: - x: Observations to check. - - Returns: - True if the observation space contains the given all observations - in x. - """ - if ( - not hasattr(self, "_obs_space_in_preferred_format") - or self._obs_space_in_preferred_format is None - ): - self._obs_space_in_preferred_format = ( - self._check_if_obs_space_maps_agent_id_to_sub_space() - ) - if self._obs_space_in_preferred_format: - for key, agent_obs in x.items(): - if not self.observation_space[key].contains(agent_obs): - return False - if not all(k in self.observation_space.spaces for k in x): - if log_once("possibly_bad_multi_agent_dict_missing_agent_observations"): - logger.warning( - "You environment returns observations that are " - "MultiAgentDicts with incomplete information. " - "Meaning that they only contain information on a subset of" - " participating agents. Ignore this warning if this is " - "intended, for example if your environment is a turn-based " - "simulation." - ) - return True - - logger.warning( - "observation_space_contains() of {} has not been implemented. " - "You " - "can either implement it yourself or bring the observation " - "space into the preferred format of a mapping from agent ids " - "to their individual observation spaces. 
".format(self) - ) - return True - - def action_space_contains(self, x: MultiAgentDict) -> bool: - """Checks if the action space contains the given action. - - Args: - x: Actions to check. - - Returns: - True if the action space contains all actions in x. - """ - if ( - not hasattr(self, "_action_space_in_preferred_format") - or self._action_space_in_preferred_format is None - ): - self._action_space_in_preferred_format = ( - self._check_if_action_space_maps_agent_id_to_sub_space() - ) - if self._action_space_in_preferred_format: - return all(self.action_space[agent].contains(x[agent]) for agent in x) - - if log_once("action_space_contains"): - logger.warning( - "action_space_contains() of {} has not been implemented. " - "You " - "can either implement it yourself or bring the observation " - "space into the preferred format of a mapping from agent ids " - "to their individual observation spaces. ".format(self) - ) - return True + def render(self) -> None: + """Tries to render the environment.""" - def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict: - """Returns a random action for each environment, and potentially each - agent in that environment. + # By default, do nothing. + pass - Args: - agent_ids: List of agent ids to sample actions for. If None or - empty list, sample actions for all agents in the - environment. + def get_observation_space(self, agent_id: AgentID) -> gym.Space: + if self.observation_spaces is not None: + return self.observation_spaces[agent_id] - Returns: - A random action for each environment. - """ + # @OldAPIStack behavior. if ( - not hasattr(self, "_action_space_in_preferred_format") - or self._action_space_in_preferred_format is None + isinstance(self.observation_space, gym.spaces.Dict) + and agent_id in self.observation_space.spaces ): - self._action_space_in_preferred_format = ( - self._check_if_action_space_maps_agent_id_to_sub_space() - ) - if self._action_space_in_preferred_format: - if agent_ids is None: - agent_ids = self.get_agent_ids() - samples = self.action_space.sample() - return { - agent_id: samples[agent_id] - for agent_id in agent_ids - if agent_id != "__all__" - } - logger.warning( - f"action_space_sample() of {self} has not been implemented. " - "You can either implement it yourself or bring the observation " - "space into the preferred format of a mapping from agent ids " - "to their individual observation spaces." - ) - return {} - - def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict: - """Returns a random observation from the observation space for each - agent if agent_ids is None, otherwise returns a random observation for - the agents in agent_ids. - - Args: - agent_ids: List of agent ids to sample actions for. If None or - empty list, sample actions for all agents in the - environment. + return self.observation_space[agent_id] + else: + return self.observation_space - Returns: - A random action for each environment. - """ + def get_action_space(self, agent_id: AgentID) -> gym.Space: + if self.action_spaces is not None: + return self.action_spaces[agent_id] + # @OldAPIStack behavior. 
if ( - not hasattr(self, "_obs_space_in_preferred_format") - or self._obs_space_in_preferred_format is None + isinstance(self.action_space, gym.spaces.Dict) + and agent_id in self.action_space.spaces ): - self._obs_space_in_preferred_format = ( - self._check_if_obs_space_maps_agent_id_to_sub_space() - ) - if self._obs_space_in_preferred_format: - if agent_ids is None: - agent_ids = self.get_agent_ids() - samples = self.observation_space.sample() - samples = {agent_id: samples[agent_id] for agent_id in agent_ids} - return samples - if log_once("observation_space_sample"): - logger.warning( - "observation_space_sample() of {} has not been implemented. " - "You " - "can either implement it yourself or bring the observation " - "space into the preferred format of a mapping from agent ids " - "to their individual observation spaces. ".format(self) - ) - return {} - - def get_agent_ids(self) -> Set[AgentID]: - """Returns a set of agent ids in the environment. - - Returns: - Set of agent ids. - """ - if not isinstance(self._agent_ids, set): - self._agent_ids = set(self._agent_ids) - return self._agent_ids + return self.action_space[agent_id] + else: + return self.action_space - def render(self) -> None: - """Tries to render the environment.""" + @property + def num_agents(self) -> int: + return len(self.agents) - # By default, do nothing. - pass + @property + def max_num_agents(self) -> int: + return len(self.possible_agents) # fmt: off # __grouping_doc_begin__ @@ -361,6 +251,16 @@ class MyMultiAgentEnv(MultiAgentEnv): # __grouping_doc_end__ # fmt: on + @OldAPIStack + @Deprecated(new="MultiAgentEnv.possible_agents", error=False) + def get_agent_ids(self) -> Set[AgentID]: + if not hasattr(self, "_agent_ids"): + self._agent_ids = set() + if not isinstance(self._agent_ids, set): + self._agent_ids = set(self._agent_ids) + # Make this backward compatible as much as possible. + return self._agent_ids if self._agent_ids else set(self.agents) + @OldAPIStack def to_base_env( self, @@ -420,22 +320,6 @@ def to_base_env( return env - def _check_if_obs_space_maps_agent_id_to_sub_space(self) -> bool: - """Checks if obs space maps from agent ids to spaces of individual agents.""" - return ( - hasattr(self, "observation_space") - and isinstance(self.observation_space, gym.spaces.Dict) - and set(self.observation_space.spaces.keys()) == self.get_agent_ids() - ) - - def _check_if_action_space_maps_agent_id_to_sub_space(self) -> bool: - """Checks if action space maps from agent ids to spaces of individual agents.""" - return ( - hasattr(self, "action_space") - and isinstance(self.action_space, gym.spaces.Dict) - and set(self.action_space.keys()) == self.get_agent_ids() - ) - @PublicAPI def make_multi_agent( @@ -491,12 +375,13 @@ def make_multi_agent( class MultiEnv(MultiAgentEnv): def __init__(self, config: EnvContext = None): - MultiAgentEnv.__init__(self) - # Note(jungong) : explicitly check for None here, because config + super().__init__() + + # Note: Explicitly check for None here, because config # can have an empty dict but meaningful data fields (worker_index, # vector_index) etc. - # TODO(jungong) : clean this up, so we are not mixing up dict fields - # with data fields. + # TODO (sven): Clean this up, so we are not mixing up dict fields + # with data fields. 
if config is None: config = {} num = config.pop("num_agents", 1) @@ -506,15 +391,12 @@ def __init__(self, config: EnvContext = None): self.envs = [env_name_or_creator(config) for _ in range(num)] self.terminateds = set() self.truncateds = set() - self.observation_space = gym.spaces.Dict( - {i: self.envs[i].observation_space for i in range(num)} - ) - self._obs_space_in_preferred_format = True - self.action_space = gym.spaces.Dict( - {i: self.envs[i].action_space for i in range(num)} - ) - self._action_space_in_preferred_format = True - self._agent_ids = set(range(num)) + self.observation_spaces = { + i: self.envs[i].observation_space for i in range(num) + } + self.action_spaces = {i: self.envs[i].action_space for i in range(num)} + self.agents = list(range(num)) + self.possible_agents = self.agents.copy() @override(MultiAgentEnv) def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): @@ -769,22 +651,6 @@ def observation_space(self) -> gym.spaces.Dict: def action_space(self) -> gym.Space: return self.envs[0].action_space - @override(BaseEnv) - def observation_space_contains(self, x: MultiEnvDict) -> bool: - return all(self.envs[0].observation_space_contains(val) for val in x.values()) - - @override(BaseEnv) - def action_space_contains(self, x: MultiEnvDict) -> bool: - return all(self.envs[0].action_space_contains(val) for val in x.values()) - - @override(BaseEnv) - def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict: - return {0: self.envs[0].observation_space_sample(agent_ids)} - - @override(BaseEnv) - def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict: - return {0: self.envs[0].action_space_sample(agent_ids)} - @override(BaseEnv) def get_agent_ids(self) -> Set[AgentID]: return self.envs[0].get_agent_ids() diff --git a/rllib/env/multi_agent_env_runner.py b/rllib/env/multi_agent_env_runner.py index 647073a96ca3..5f1ebd184fce 100644 --- a/rllib/env/multi_agent_env_runner.py +++ b/rllib/env/multi_agent_env_runner.py @@ -14,7 +14,6 @@ ) from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec -from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -256,20 +255,12 @@ def _sample_timesteps( while ts < num_timesteps: # Act randomly. if random_actions: - # Note, to get sampled actions from all agents' action - # spaces we need to call `MultiAgentEnv.action_space_sample()`. - if self.env.unwrapped._action_space_in_preferred_format: - actions = self.env.action_space.sample() - # Otherwise, `action_space_sample()` needs to be implemented. - else: - actions = self.env.action_space_sample() - # Remove all actions for agents that had no observation. + # Only act (randomly) for those agents that had an observation. to_env = { Columns.ACTIONS: [ { - agent_id: agent_action - for agent_id, agent_action in actions.items() - if agent_id in self._episode.get_agents_to_act() + aid: self.env.get_action_space(aid).sample() + for aid in self._episode.get_agents_to_act() } ] } @@ -465,20 +456,14 @@ def _sample_episodes( while eps < num_episodes: # Act randomly. if random_actions: - # Note, to get sampled actions from all agents' action - # spaces we need to call `MultiAgentEnv.action_space_sample()`. 
- if self.env.unwrapped._action_space_in_preferred_format: - actions = self.env.action_space.sample() - # Otherwise, `action_space_sample()` needs to be implemented. - else: - actions = self.env.action_space_sample() - # Remove all actions for agents that had no observation. + # Only act (randomly) for those agents that had an observation. to_env = { - Columns.ACTIONS: { - agent_id: agent_action - for agent_id, agent_action in actions.items() - if agent_id in _episode.get_agents_to_act() - }, + Columns.ACTIONS: [ + { + aid: self.env.get_action_space(aid).sample() + for aid in self._episode.get_agents_to_act() + } + ] } # Compute an action using the RLModule. else: @@ -620,10 +605,9 @@ def _sample_episodes( @override(EnvRunner) def get_spaces(self): + # Return the already agent-to-module translated spaces from our connector + # pipeline. return { - INPUT_ENV_SPACES: (self.env.observation_space, self.env.action_space), - # Use the already agent-to-module translated spaces from our connector - # pipeline. **{ mid: (o, self._env_to_module.action_space[mid]) for mid, o in self._env_to_module.observation_space.spaces.items() @@ -704,6 +688,7 @@ def get_state( not_components: Optional[Union[str, Collection[str]]] = None, **kwargs, ) -> StateDict: + # Basic state dict. state = { WEIGHTS_SEQ_NO: self._weights_seq_no, NUM_ENV_STEPS_SAMPLED_LIFETIME: ( @@ -712,6 +697,7 @@ def get_state( "agent_to_module_mapping_fn": self.config.policy_mapping_fn, } + # RLModule (MultiRLModule) component. if self._check_component(COMPONENT_RL_MODULE, components, not_components): state[COMPONENT_RL_MODULE] = self.module.get_state( components=self._get_subcomponents(COMPONENT_RL_MODULE, components), @@ -720,10 +706,12 @@ def get_state( ), **kwargs, ) + # Env-to-module connector. if self._check_component( COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components ): state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state() + # Module-to-env connector. 
if self._check_component( COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components ): @@ -898,8 +886,13 @@ def _setup_metrics(self): def _new_episode(self): return MultiAgentEpisode( - observation_space=self.env.observation_space, - action_space=self.env.action_space, + observation_space={ + aid: self.env.get_observation_space(aid) + for aid in self.env.possible_agents + }, + action_space={ + aid: self.env.get_action_space(aid) for aid in self.env.possible_agents + }, agent_to_module_mapping_fn=self.config.policy_mapping_fn, ) diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py index 153e5bc14919..9febd9cc05d6 100644 --- a/rllib/env/tests/test_multi_agent_env.py +++ b/rllib/env/tests/test_multi_agent_env.py @@ -9,7 +9,6 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.env.multi_agent_env import ( - make_multi_agent, MultiAgentEnv, MultiAgentEnvWrapper, ) @@ -39,8 +38,8 @@ class BasicMultiAgent(MultiAgentEnv): def __init__(self, num): super().__init__() - self.agents = [MockEnv(25) for _ in range(num)] - self._agent_ids = set(range(num)) + self.envs = [MockEnv(25) for _ in range(num)] + self.agents = list(range(num)) self.terminateds = set() self.truncateds = set() self.observation_space = gym.spaces.Discrete(2) @@ -55,7 +54,7 @@ def reset(self, *, seed=None, options=None): self.resetted = True self.terminateds = set() self.truncateds = set() - reset_results = [a.reset() for a in self.agents] + reset_results = [a.reset() for a in self.envs] return ( {i: oi[0] for i, oi in enumerate(reset_results)}, {i: oi[1] for i, oi in enumerate(reset_results)}, @@ -64,15 +63,15 @@ def reset(self, *, seed=None, options=None): def step(self, action_dict): obs, rew, terminated, truncated, info = {}, {}, {}, {}, {} for i, action in action_dict.items(): - obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step( + obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[i].step( action ) if terminated[i]: self.terminateds.add(i) if truncated[i]: self.truncateds.add(i) - terminated["__all__"] = len(self.terminateds) == len(self.agents) - truncated["__all__"] = len(self.truncateds) == len(self.agents) + terminated["__all__"] = len(self.terminateds) == len(self.envs) + truncated["__all__"] = len(self.truncateds) == len(self.envs) return obs, rew, terminated, truncated, info def render(self): @@ -87,8 +86,8 @@ class EarlyDoneMultiAgent(MultiAgentEnv): def __init__(self): super().__init__() - self.agents = [MockEnv(3), MockEnv(5)] - self._agent_ids = set(range(len(self.agents))) + self.envs = [MockEnv(3), MockEnv(5)] + self.agents = list(range(len(self.envs))) self.terminateds = set() self.truncateds = set() self.last_obs = {} @@ -109,18 +108,18 @@ def reset(self, *, seed=None, options=None): self.last_truncated = {} self.last_info = {} self.i = 0 - for i, a in enumerate(self.agents): + for i, a in enumerate(self.envs): self.last_obs[i], self.last_info[i] = a.reset() self.last_rew[i] = 0 self.last_terminated[i] = False self.last_truncated[i] = False obs_dict = {self.i: self.last_obs[self.i]} info_dict = {self.i: self.last_info[self.i]} - self.i = (self.i + 1) % len(self.agents) + self.i = (self.i + 1) % len(self.envs) return obs_dict, info_dict def step(self, action_dict): - assert len(self.terminateds) != len(self.agents) + assert len(self.terminateds) != len(self.envs) for i, action in action_dict.items(): ( self.last_obs[i], @@ -128,7 +127,7 @@ def 
step(self, action_dict): self.last_terminated[i], self.last_truncated[i], self.last_info[i], - ) = self.agents[i].step(action) + ) = self.envs[i].step(action) obs = {self.i: self.last_obs[self.i]} rew = {self.i: self.last_rew[self.i]} terminated = {self.i: self.last_terminated[self.i]} @@ -140,9 +139,9 @@ def step(self, action_dict): if truncated[self.i]: rew[self.i] = 0 self.truncateds.add(self.i) - self.i = (self.i + 1) % len(self.agents) - terminated["__all__"] = len(self.terminateds) == len(self.agents) - 1 - truncated["__all__"] = len(self.truncateds) == len(self.agents) - 1 + self.i = (self.i + 1) % len(self.envs) + terminated["__all__"] = len(self.terminateds) == len(self.envs) - 1 + truncated["__all__"] = len(self.truncateds) == len(self.envs) - 1 return obs, rew, terminated, truncated, info @@ -151,11 +150,13 @@ class FlexAgentsMultiAgent(MultiAgentEnv): def __init__(self): super().__init__() - self.agents = {} - self._agent_ids = set() + self.envs = {} + self.agents = [] + self.possible_agents = list(range(10000)) # Absolute max. number of agents. self.agentID = 0 self.terminateds = set() self.truncateds = set() + # All agents have the exact same spaces. self.observation_space = gym.spaces.Discrete(2) self.action_space = gym.spaces.Discrete(2) self.resetted = False @@ -163,21 +164,25 @@ def __init__(self): def spawn(self): # Spawn a new agent into the current episode. agentID = self.agentID - self.agents[agentID] = MockEnv(25) - self._agent_ids.add(agentID) + self.envs[agentID] = MockEnv(25) + self.agents.append(agentID) self.agentID += 1 return agentID + def kill(self, agent_id): + del self.envs[agent_id] + self.agents.remove(agent_id) + def reset(self, *, seed=None, options=None): - self.agents = {} - self._agent_ids = set() + self.envs = {} + self.agents.clear() self.spawn() self.resetted = True self.terminateds = set() self.truncateds = set() obs = {} infos = {} - for i, a in self.agents.items(): + for i, a in self.envs.items(): obs[i], infos[i] = a.reset() return obs, infos @@ -186,7 +191,7 @@ def step(self, action_dict): obs, rew, terminated, truncated, info = {}, {}, {}, {}, {} # Apply the actions. for i, action in action_dict.items(): - obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step( + obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[i].step( action ) if terminated[i]: @@ -196,24 +201,25 @@ def step(self, action_dict): # Sometimes, add a new agent to the episode. if random.random() > 0.75 and len(action_dict) > 0: - i = self.spawn() - obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step( - action - ) - if terminated[i]: - self.terminateds.add(i) - if truncated[i]: - self.truncateds.add(i) + aid = self.spawn() + obs[aid], rew[aid], terminated[aid], truncated[aid], info[aid] = self.envs[ + aid + ].step(action) + if terminated[aid]: + self.terminateds.add(aid) + if truncated[aid]: + self.truncateds.add(aid) # Sometimes, kill an existing agent. 
- if len(self.agents) > 1 and random.random() > 0.25: - keys = list(self.agents.keys()) - key = random.choice(keys) - terminated[key] = True - del self.agents[key] - - terminated["__all__"] = len(self.terminateds) == len(self.agents) - truncated["__all__"] = len(self.truncateds) == len(self.agents) + if len(self.envs) > 1 and random.random() > 0.25: + keys = list(self.envs.keys()) + aid = random.choice(keys) + self.kill(aid) + terminated[aid] = True + self.terminateds.add(aid) + + terminated["__all__"] = len(self.terminateds) == len(self.envs) + truncated["__all__"] = len(self.truncateds) == len(self.envs) return obs, rew, terminated, truncated, info @@ -229,9 +235,8 @@ class SometimesZeroAgentsMultiAgent(MultiAgentEnv): def __init__(self, num=3): super().__init__() - self.num_agents = num - self.agents = [MockEnv(25) for _ in range(self.num_agents)] - self._agent_ids = set(range(self.num_agents)) + self.agents = list(range(num)) + self.envs = [MockEnv(25) for _ in range(self.num_agents)] self._observations = {} self._infos = {} self.terminateds = set() @@ -245,7 +250,7 @@ def reset(self, *, seed=None, options=None): self._observations = {} self._infos = {} for aid in self._get_random_agents(): - self._observations[aid], self._infos[aid] = self.agents[aid].reset() + self._observations[aid], self._infos[aid] = self.envs[aid].reset() return self._observations, self._infos def step(self, action_dict): @@ -258,7 +263,7 @@ def step(self, action_dict): terminated[aid], truncated[aid], self._infos[aid], - ) = self.agents[aid].step(action) + ) = self.envs[aid].step(action) if terminated[aid]: self.terminateds.add(aid) if truncated[aid]: @@ -306,10 +311,10 @@ def __init__(self, num, increment_obs=False): super().__init__() if increment_obs: # Observations are 0, 1, 2, 3... etc. 
as time advances - self.agents = [MockEnv2(5) for _ in range(num)] + self.envs = [MockEnv2(5) for _ in range(num)] else: # Observations are all zeros - self.agents = [MockEnv(5) for _ in range(num)] + self.envs = [MockEnv(5) for _ in range(num)] self._agent_ids = set(range(num)) self.terminateds = set() self.truncateds = set() @@ -334,7 +339,7 @@ def reset(self, *, seed=None, options=None): self.last_truncated = {} self.last_info = {} self.i = 0 - for i, a in enumerate(self.agents): + for i, a in enumerate(self.envs): self.last_obs[i], self.last_info[i] = a.reset() self.last_rew[i] = 0 self.last_terminated[i] = False @@ -345,7 +350,7 @@ def reset(self, *, seed=None, options=None): return obs_dict, info_dict def step(self, action_dict): - assert len(self.terminateds) != len(self.agents) + assert len(self.terminateds) != len(self.envs) for i, action in action_dict.items(): ( self.last_obs[i], @@ -353,7 +358,7 @@ def step(self, action_dict): self.last_terminated[i], self.last_truncated[i], self.last_info[i], - ) = self.agents[i].step(action) + ) = self.envs[i].step(action) obs = {self.i: self.last_obs[self.i]} rew = {self.i: self.last_rew[self.i]} terminated = {self.i: self.last_terminated[self.i]} @@ -365,8 +370,8 @@ def step(self, action_dict): if truncated[self.i]: self.truncateds.add(self.i) self.i = (self.i + 1) % self.num - terminated["__all__"] = len(self.terminateds) == len(self.agents) - truncated["__all__"] = len(self.truncateds) == len(self.agents) + terminated["__all__"] = len(self.terminateds) == len(self.envs) + truncated["__all__"] = len(self.truncateds) == len(self.envs) return obs, rew, terminated, truncated, info @@ -781,7 +786,6 @@ def get_initial_state(self): return [{}] # empty dict def is_recurrent(self): - # TODO: avnishn automatically infer this. return True ev = RolloutWorker( @@ -806,58 +810,6 @@ def is_recurrent(self): check(batch["state_in_0"][i], h) check(batch["state_out_0"][i], h) - def test_space_in_preferred_format(self): - env = NestedMultiAgentEnv() - action_space_in_preferred_format = ( - env._check_if_action_space_maps_agent_id_to_sub_space() - ) - obs_space_in_preferred_format = ( - env._check_if_obs_space_maps_agent_id_to_sub_space() - ) - assert action_space_in_preferred_format, "Act space is not in preferred format." - assert obs_space_in_preferred_format, "Obs space is not in preferred format." - - env2 = make_multi_agent("CartPole-v1")() - action_spaces_in_preferred_format = ( - env2._check_if_action_space_maps_agent_id_to_sub_space() - ) - obs_space_in_preferred_format = ( - env2._check_if_obs_space_maps_agent_id_to_sub_space() - ) - assert ( - action_spaces_in_preferred_format - ), "Action space should be in preferred format but isn't." - assert ( - obs_space_in_preferred_format - ), "Observation space should be in preferred format but isn't." 
- - def test_spaces_sample_contain_in_preferred_format(self): - env = NestedMultiAgentEnv() - # this environment has spaces that are in the preferred format - # for multi-agent environments where the spaces are dict spaces - # mapping agent-ids to sub-spaces - obs = env.observation_space_sample() - assert env.observation_space_contains( - obs - ), "Observation space does not contain obs" - - action = env.action_space_sample() - assert env.action_space_contains(action), "Action space does not contain action" - - def test_spaces_sample_contain_not_in_preferred_format(self): - env = make_multi_agent("CartPole-v1")({"num_agents": 2}) - # this environment has spaces that are not in the preferred format - # for multi-agent environments where the spaces not in the preferred - # format, users must override the observation_space_contains, - # action_space_contains observation_space_sample, - # and action_space_sample methods in order to do proper checks - obs = env.observation_space_sample() - assert env.observation_space_contains( - obs - ), "Observation space does not contain obs" - action = env.action_space_sample() - assert env.action_space_contains(action), "Action space does not contain action" - if __name__ == "__main__": import pytest diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index 15dde7e56ca0..a0189b092339 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -2,7 +2,7 @@ import numpy as np import unittest -from typing import List, Optional, Tuple +from typing import Optional, Tuple import ray from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -124,12 +124,6 @@ def step( return obs, reward, is_terminated, is_truncated, info - def action_space_sample(self, agent_ids: List[str] = None) -> MultiAgentDict: - # Actually not used at this stage. - return { - agent_id: self.action_space[agent_id].sample() for agent_id in agent_ids - } - # TODO (simon): Test `get_state()` and `from_state()`. class TestMultiAgentEpisode(unittest.TestCase): @@ -229,7 +223,6 @@ def test_init(self): ] action = {agent_id: i + 1 for agent_id in agents_to_step_next} - # action = env.action_space_sample(agents_stepped) obs, reward, terminated, truncated, info = env.step(action) # If "agent_0" is part of the reset obs, it steps in the first ts. @@ -270,7 +263,7 @@ def test_init(self): self.assertTrue(episode.agent_episodes["agent_1"].is_terminated) self.assertTrue(episode.agent_episodes["agent_5"].is_terminated) # Assert that the other agents are neither terminated nor truncated. - for agent_id in env.get_agent_ids(): + for agent_id in env.agents: if agent_id != "agent_1" and agent_id != "agent_5": self.assertFalse(episode.agent_episodes[agent_id].is_done) @@ -362,7 +355,7 @@ def test_add_env_reset(self): self.assertTrue(episode.env_t == episode.env_t_started == 0) # Assert that the agents with initial observations have their single-agent # episodes in place. - for agent_id in env.get_agent_ids(): + for agent_id in env.agents: # Ensure that all agents have a single env_ts=0 -> agent_ts=0 # entry in their env- to agent-timestep mappings. if agent_id in obs: @@ -3440,7 +3433,7 @@ def test_get_sample_batch(self): self.assertTrue(batch[agent_id]["truncateds"][-1]) # Finally, test that an empty episode, gives an empty batch. - episode = MultiAgentEpisode(agent_ids=env.get_agent_ids()) + episode = MultiAgentEpisode(agent_ids=env.agents) # Convert now to sample batch. 
batch = episode.get_sample_batch() # Ensure that this batch is empty. diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 16c0d82d569d..c3e0896ba05e 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -518,16 +518,6 @@ def observation_space(self) -> gym.Space: def action_space(self) -> gym.Space: return self._action_space - @override(BaseEnv) - def action_space_sample(self, agent_id: list = None) -> MultiEnvDict: - del agent_id - return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}} - - @override(BaseEnv) - def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict: - del agent_id - return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}} - @override(BaseEnv) def get_agent_ids(self) -> Set[AgentID]: return {_DUMMY_AGENT_ID} diff --git a/rllib/env/wrappers/open_spiel.py b/rllib/env/wrappers/open_spiel.py index f18dc675bf24..1bc7ba119e68 100644 --- a/rllib/env/wrappers/open_spiel.py +++ b/rllib/env/wrappers/open_spiel.py @@ -20,10 +20,6 @@ def __init__(self, env): # Stores the current open-spiel game state. self.state = None - # Extract observation- and action spaces from game. - self._obs_space_in_preferred_format = True - self._action_space_in_preferred_format = True - self.observation_space = gym.spaces.Dict( { aid: gym.spaces.Box( diff --git a/rllib/env/wrappers/pettingzoo_env.py b/rllib/env/wrappers/pettingzoo_env.py index 31627b948a83..f7ee4cf4d6b2 100644 --- a/rllib/env/wrappers/pettingzoo_env.py +++ b/rllib/env/wrappers/pettingzoo_env.py @@ -4,7 +4,6 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.typing import MultiAgentDict @PublicAPI @@ -128,23 +127,9 @@ def __init__(self, env): self.observation_space = gym.spaces.Dict( {aid: self.env.observation_space(aid) for aid in self._agent_ids} ) - self._obs_space_in_preferred_format = True self.action_space = gym.spaces.Dict( {aid: self.env.action_space(aid) for aid in self._agent_ids} ) - self._action_space_in_preferred_format = True - - def observation_space_sample(self, agent_ids: list = None) -> MultiAgentDict: - sample = self.observation_space.sample() - if agent_ids is None: - return sample - return {aid: sample[aid] for aid in agent_ids} - - def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict: - sample = self.action_space.sample() - if agent_ids is None: - return sample - return {aid: sample[aid] for aid in agent_ids} def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): info = self.env.reset(seed=seed, options=options) @@ -204,11 +189,9 @@ def __init__(self, env): self.observation_space = gym.spaces.Dict( {aid: self.par_env.observation_space(aid) for aid in self._agent_ids} ) - self._obs_space_in_preferred_format = True self.action_space = gym.spaces.Dict( {aid: self.par_env.action_space(aid) for aid in self._agent_ids} ) - self._action_space_in_preferred_format = True def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): obs, info = self.par_env.reset(seed=seed, options=options) diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 145f4695f849..42f4813885f3 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -897,9 +897,6 @@ class MockMultiAgentEnv(MultiAgentEnv): """A mock testing MultiAgentEnv that doesn't call super.__init__().""" def __init__(self): - # Intentinoally don't call 
super().__init__(), - # so this env doesn't have - # `self._[action|observation]_space_in_preferred_format`attributes. self.observation_space = gym.spaces.Discrete(2) self.action_space = gym.spaces.Discrete(2) diff --git a/rllib/examples/connectors/prev_actions_prev_rewards.py b/rllib/examples/connectors/prev_actions_prev_rewards.py index 390f769be16d..2668dfb9f4c2 100644 --- a/rllib/examples/connectors/prev_actions_prev_rewards.py +++ b/rllib/examples/connectors/prev_actions_prev_rewards.py @@ -100,6 +100,7 @@ parser = add_rllib_example_script_args( default_reward=200.0, default_timesteps=1000000, default_iters=2000 ) +parser.set_defaults(enable_new_api_stack=True) parser.add_argument("--n-prev-rewards", type=int, default=1) parser.add_argument("--n-prev-actions", type=int, default=1) @@ -107,22 +108,16 @@ if __name__ == "__main__": args = parser.parse_args() - assert ( - args.enable_new_api_stack - ), "Must set --enable-new-api-stack when running this script!" - # Define our custom connector pipelines. def _env_to_module(env): # Create the env-to-module connector pipeline. return [ - # AddObservationsFromEpisodesToBatch(), PrevActionsPrevRewards( multi_agent=args.num_agents > 0, n_prev_rewards=args.n_prev_rewards, n_prev_actions=args.n_prev_actions, ), FlattenObservations(multi_agent=args.num_agents > 0), - # WriteObservationsToEpisodes(), ] # Register our environment with tune. diff --git a/rllib/examples/envs/classes/debug_counter_env.py b/rllib/examples/envs/classes/debug_counter_env.py index 69dc0870f62a..404833e18d8d 100644 --- a/rllib/examples/envs/classes/debug_counter_env.py +++ b/rllib/examples/envs/classes/debug_counter_env.py @@ -50,7 +50,6 @@ def __init__(self, config): for aid in range(self.num_agents) } ) - self._obs_space_in_preferred_format = True # Actions are always: # (episodeID, envID) as floats. @@ -60,7 +59,6 @@ def __init__(self, config): for aid in range(self.num_agents) } ) - self._action_space_in_preferred_format = True self.timesteps = [0] * self.num_agents self.terminateds = set() diff --git a/rllib/examples/envs/classes/two_step_game.py b/rllib/examples/envs/classes/two_step_game.py index bd40b03a0b08..1bf606f7b405 100644 --- a/rllib/examples/envs/classes/two_step_game.py +++ b/rllib/examples/envs/classes/two_step_game.py @@ -114,9 +114,7 @@ def __init__(self, env_config): act_space=tuple_act_space, ) self.observation_space = Dict({"agents": self.env.observation_space}) - self._obs_space_in_preferred_format = True self.action_space = Dict({"agents": self.env.action_space}) - self._action_space_in_preferred_format = True def reset(self, *, seed=None, options=None): return self.env.reset(seed=seed, options=options) diff --git a/rllib/examples/multi_agent/different_spaces_for_agents.py b/rllib/examples/multi_agent/different_spaces_for_agents.py index ec543de185fc..7331a3e3aadc 100644 --- a/rllib/examples/multi_agent/different_spaces_for_agents.py +++ b/rllib/examples/multi_agent/different_spaces_for_agents.py @@ -42,37 +42,33 @@ class BasicMultiAgentMultiSpaces(MultiAgentEnv): """ def __init__(self, config=None): - self.agents = {"agent0", "agent1"} - self._agent_ids = set(self.agents) + self.agents = ["agent0", "agent1"] self.terminateds = set() self.truncateds = set() # Provide full (preferred format) observation- and action-spaces as Dicts # mapping agent IDs to the individual agents' spaces. 
- self._obs_space_in_preferred_format = True - self.observation_space = gym.spaces.Dict( - { - "agent0": gym.spaces.Box(low=-1.0, high=1.0, shape=(10,)), - "agent1": gym.spaces.Box(low=-1.0, high=1.0, shape=(20,)), - } - ) - self._action_space_in_preferred_format = True - self.action_space = gym.spaces.Dict( - {"agent0": gym.spaces.Discrete(2), "agent1": gym.spaces.Discrete(3)} - ) + self.observation_spaces = { + "agent0": gym.spaces.Box(low=-1.0, high=1.0, shape=(10,)), + "agent1": gym.spaces.Box(low=-1.0, high=1.0, shape=(20,)), + } + self.action_spaces = { + "agent0": gym.spaces.Discrete(2), + "agent1": gym.spaces.Discrete(3), + } super().__init__() def reset(self, *, seed=None, options=None): self.terminateds = set() self.truncateds = set() - return {i: self.observation_space[i].sample() for i in self.agents}, {} + return {i: self.get_observation_space(i).sample() for i in self.agents}, {} def step(self, action_dict): obs, rew, terminated, truncated, info = {}, {}, {}, {}, {} for i, action in action_dict.items(): - obs[i] = self.observation_space[i].sample() + obs[i] = self.get_observation_space(i).sample() rew[i] = 0.0 terminated[i] = False truncated[i] = False diff --git a/rllib/utils/pre_checks/env.py b/rllib/utils/pre_checks/env.py index 0f74d9a64ffc..7055f08695fb 100644 --- a/rllib/utils/pre_checks/env.py +++ b/rllib/utils/pre_checks/env.py @@ -1,6 +1,5 @@ """Common pre-checks for all RLlib experiments.""" import logging -from copy import copy from typing import TYPE_CHECKING, Set import gymnasium as gym @@ -33,8 +32,6 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None: hasattr(env, "observation_space") and hasattr(env, "action_space") and hasattr(env, "_agent_ids") - and hasattr(env, "_obs_space_in_preferred_format") - and hasattr(env, "_action_space_in_preferred_format") ): if log_once("ma_env_super_ctor_called"): logger.warning( @@ -55,49 +52,15 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None: ) from e reset_obs, reset_infos = obs_and_infos - sampled_obs = env.observation_space_sample() _check_if_element_multi_agent_dict(env, reset_obs, "reset()") + + sampled_action = { + aid: env.get_action_space(aid).sample() for aid in reset_obs.keys() + } _check_if_element_multi_agent_dict( - env, sampled_obs, "env.observation_space_sample()" + env, sampled_action, "get_action_space(agent_id=..).sample()" ) - try: - env.observation_space_contains(reset_obs) - except Exception as e: - raise ValueError( - "Your observation_space_contains function has some error " - ) from e - - if not env.observation_space_contains(reset_obs): - error = ( - _not_contained_error("env.reset", "observation") - + f"\n\n reset_obs: {reset_obs}\n\n env.observation_space_sample():" - f" {sampled_obs}\n\n " - ) - raise ValueError(error) - - if not env.observation_space_contains(sampled_obs): - error = ( - _not_contained_error("observation_space_sample", "observation") - + f"\n\n env.observation_space_sample():" - f" {sampled_obs}\n\n " - ) - raise ValueError(error) - - sampled_action = env.action_space_sample(list(reset_obs.keys())) - _check_if_element_multi_agent_dict(env, sampled_action, "action_space_sample") - try: - env.action_space_contains(sampled_action) - except Exception as e: - raise ValueError("Your action_space_contains function has some error ") from e - - if not env.action_space_contains(sampled_action): - error = ( - _not_contained_error("action_space_sample", "action") - + f"\n\n sampled_action {sampled_action}\n\n" - ) - raise ValueError(error) - try: results = 
env.step(sampled_action) except Exception as e: @@ -113,22 +76,14 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None: _check_if_element_multi_agent_dict(env, done, "step, done") _check_if_element_multi_agent_dict(env, truncated, "step, truncated") _check_if_element_multi_agent_dict(env, info, "step, info", allow_common=True) - _check_reward( - {"dummy_env_id": reward}, base_env=True, agent_ids=env.get_agent_ids() - ) + _check_reward({"dummy_env_id": reward}, base_env=True, agent_ids=env.agents) _check_done_and_truncated( {"dummy_env_id": done}, {"dummy_env_id": truncated}, base_env=True, - agent_ids=env.get_agent_ids(), + agent_ids=env.agents, ) - _check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.get_agent_ids()) - if not env.observation_space_contains(next_obs): - error = ( - _not_contained_error("env.step(sampled_action)", "observation") - + f":\n\n next_obs: {next_obs} \n\n sampled_obs: {sampled_obs}" - ) - raise ValueError(error) + _check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.agents) def _check_reward(reward, base_env=False, agent_ids=None): @@ -152,8 +107,8 @@ def _check_reward(reward, base_env=False, agent_ids=None): if not (agent_id in agent_ids or agent_id == "__all__"): error = ( f"Your reward dictionary must have agent ids that belong to " - f"the environment. Agent_ids recieved from " - f"env.get_agent_ids() are: {agent_ids}" + f"the environment. AgentIDs received from " + f"env.agents are: {agent_ids}" ) raise ValueError(error) elif not ( @@ -185,8 +140,8 @@ def _check_done_and_truncated(done, truncated, base_env=False, agent_ids=None): if not (agent_id in agent_ids or agent_id == "__all__"): error = ( f"Your `{what}s` dictionary must have agent ids that " - f"belong to the environment. Agent_ids recieved from " - f"env.get_agent_ids() are: {agent_ids}" + f"belong to the environment. AgentIDs received from " + f"env.agents are: {agent_ids}" ) raise ValueError(error) elif not isinstance(data, (bool, np.bool_)): @@ -213,8 +168,8 @@ def _check_info(info, base_env=False, agent_ids=None): ): error = ( f"Your dones dictionary must have agent ids that belong to " - f"the environment. Agent_ids received from " - f"env.get_agent_ids() are: {agent_ids}" + f"the environment. AgentIDs received from " + f"env.agents are: {agent_ids}" ) raise ValueError(error) elif not isinstance(info, dict): @@ -257,7 +212,7 @@ def _check_if_element_multi_agent_dict( f" {type(element)}" ) raise ValueError(error) - agent_ids: Set = copy(env.get_agent_ids()) + agent_ids: Set = set(env.agents) agent_ids.add("__all__") if allow_common: agent_ids.add("__common__") @@ -268,18 +223,19 @@ def _check_if_element_multi_agent_dict( f"The element returned by {function_string} has agent_ids" f" that are not the names of the agents in the env." f"agent_ids in this\nMultiEnvDict:" - f" {list(element.keys())}\nAgent_ids in this env:" - f"{list(env.get_agent_ids())}" + f" {list(element.keys())}\nAgentIDs in this env: " + f"{env.agents}" ) else: error = ( f"The element returned by {function_string} has agent_ids" f" that are not the names of the agents in the env. " - f"\nAgent_ids in this MultiAgentDict: " - f"{list(element.keys())}\nAgent_ids in this env:" - f"{list(env.get_agent_ids())}. You likely need to add the private " - f"attribute `_agent_ids` to your env, which is a set containing the " - f"ids of agents supported by your env." + f"\nAgentIDs in this MultiAgentDict: " + f"{list(element.keys())}\nAgentIDs in this env: " + f"{env.agents}. 
You likely need to add the attribute `agents` to your "
+                f"env, which is a list containing the IDs of agents currently in your "
+                f"env/episode, as well as `possible_agents`, which is a list of all "
+                f"possible agents that could ever show up in your env."
             )
         raise ValueError(error)

diff --git a/rllib/utils/replay_buffers/multi_agent_episode_buffer.py b/rllib/utils/replay_buffers/multi_agent_episode_buffer.py
index 74b629304875..c395076cb4cc 100644
--- a/rllib/utils/replay_buffers/multi_agent_episode_buffer.py
+++ b/rllib/utils/replay_buffers/multi_agent_episode_buffer.py
@@ -59,8 +59,8 @@ class MultiAgentEpisodeReplayBuffer(EpisodeReplayBuffer):
         env = MultiAgentCartPole({"num_agents": 2})

         # Set up the loop variables
-        agent_ids = env.get_agent_ids()
-        agent_ids.add("__all__")
+        agent_ids = list(env.agents)  # Copy; don't mutate `env.agents` in place.
+        agent_ids.append("__all__")
         terminateds = {aid: False for aid in agent_ids}
         truncateds = {aid: False for aid in agent_ids}
         num_timesteps = 10000
@@ -82,13 +82,11 @@ class MultiAgentEpisodeReplayBuffer(EpisodeReplayBuffer):
         obs, infos = env.reset()
         eps.add_env_reset(observations=obs, infos=infos)

-        # Note, `action_space_sample` samples an action for all agents not only the
-        # ones still alive, but the `MultiAgentEpisode.add_env_step` does not accept
-        # results for dead agents.
+        # Sample a random action for all agents that should step in the episode
+        # next.
         actions = {
-            aid: act
-            for aid, act in env.action_space_sample().items()
-            if aid not in (env.terminateds or env.truncateds)
+            aid: env.get_action_space(aid).sample()
+            for aid in eps.get_agents_to_act()
         }
         obs, rewards, terminateds, truncateds, infos = env.step(actions)
         eps.add_env_step(
@@ -481,14 +479,15 @@ def get_sampled_timesteps(self, module_id: Optional[ModuleID] = None) -> int:

     @override(EpisodeReplayBuffer)
     def get_added_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
-        """Returns number of timesteps that have been added in buffer's lifetime for a module.
+        """Returns the number of timesteps ever added to the buffer for a module.

         Args:
-            module_id: The ID of the module to query. If not provided, the number of
-
+            module_id: The ID of the module to query. If not provided, the total
+                number of timesteps ever added (across all modules) is returned.

         Returns:
-            The number of timesteps added for the module or all modules.
+            The number of timesteps added for `module_id` (or for all modules if
+            `module_id` is None).
         """
         return (
             self._num_module_timesteps_added[module_id]
diff --git a/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py b/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py
index f1067ee64a23..f0fa2e93681d 100755
--- a/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py
+++ b/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py
@@ -58,6 +58,7 @@ class MultiAgentPrioritizedEpisodeReplayBuffer(
         sampled timestep indices).

    .. 
testcode::
+
         import gymnasium as gym

         from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
@@ -71,8 +72,8 @@ class MultiAgentPrioritizedEpisodeReplayBuffer(
         env = MultiAgentCartPole({"num_agents": 2})

         # Set up the loop variables
-        agent_ids = env.get_agent_ids()
-        agent_ids.add("__all__")
+        agent_ids = list(env.agents)  # Copy; don't mutate `env.agents` in place.
+        agent_ids.append("__all__")
         terminateds = {aid: False for aid in agent_ids}
         truncateds = {aid: False for aid in agent_ids}
         num_timesteps = 10000
@@ -94,13 +95,11 @@ class MultiAgentPrioritizedEpisodeReplayBuffer(
         obs, infos = env.reset()
         eps.add_env_reset(observations=obs, infos=infos)

-        # Note, `action_space_sample` samples an action for all agents not only the
-        # ones still alive, but the `MultiAgentEpisode.add_env_step` does not accept
-        # results for dead agents.
+        # Sample a random action for all agents that should step in the episode
+        # next.
         actions = {
-            aid: act
-            for aid, act in env.action_space_sample().items()
-            if aid not in (env.terminateds or env.truncateds)
+            aid: env.get_action_space(aid).sample()
+            for aid in eps.get_agents_to_act()
         }
         obs, rewards, terminateds, truncateds, infos = env.step(actions)
         eps.add_env_step(
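
Not part of the diff above: a minimal sketch of an env written against the new per-agent
space API that this patch exercises (`agents`, `possible_agents`, `observation_spaces` /
`action_spaces`, and `get_observation_space()` / `get_action_space()`). Class name, agent
IDs, and reward values are illustrative only, assuming the base class resolves the
per-agent getters from the `observation_spaces` / `action_spaces` dicts, as the
different_spaces_for_agents.py example above does.

# Illustrative only -- not part of this patch.
import gymnasium as gym

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class TwoAgentDemoEnv(MultiAgentEnv):
    def __init__(self, config=None):
        super().__init__()
        # `agents`: IDs currently in the episode; `possible_agents`: all IDs that
        # could ever show up.
        self.agents = self.possible_agents = ["agent_0", "agent_1"]
        # Per-agent spaces; `get_observation_space(aid)` / `get_action_space(aid)`
        # resolve from these dicts.
        self.observation_spaces = {
            "agent_0": gym.spaces.Box(-1.0, 1.0, (4,)),
            "agent_1": gym.spaces.Box(-1.0, 1.0, (8,)),
        }
        self.action_spaces = {
            "agent_0": gym.spaces.Discrete(2),
            "agent_1": gym.spaces.Discrete(3),
        }

    def reset(self, *, seed=None, options=None):
        obs = {aid: self.get_observation_space(aid).sample() for aid in self.agents}
        return obs, {}

    def step(self, action_dict):
        # Return dummy observations/rewards only for the agents that just acted.
        obs = {aid: self.get_observation_space(aid).sample() for aid in action_dict}
        rew = {aid: 0.0 for aid in action_dict}
        terminateds = {aid: False for aid in action_dict}
        terminateds["__all__"] = False
        truncateds = {"__all__": False}
        return obs, rew, terminateds, truncateds, {}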