Skip to content

Commit

Permalink
[RLlib] MultiAgentEnv API enhancements (related to defining obs-/acti…
Browse files Browse the repository at this point in the history
…on spaces for agents). (#47830)
  • Loading branch information
sven1977 committed Sep 28, 2024
1 parent b676d02 commit e07594e
Show file tree
Hide file tree
Showing 25 changed files with 473 additions and 793 deletions.
209 changes: 125 additions & 84 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,12 +982,26 @@ def build_env_to_module_connector(self, env):
f"pipeline)! Your function returned {val_}."
)

obs_space = getattr(env, "single_observation_space", env.observation_space)
if obs_space is None and self.is_multi_agent():
obs_space = gym.spaces.Dict(
{
aid: env.get_observation_space(aid)
for aid in env.unwrapped.possible_agents
}
)
act_space = getattr(env, "single_action_space", env.action_space)
if act_space is None and self.is_multi_agent():
act_space = gym.spaces.Dict(
{
aid: env.get_action_space(aid)
for aid in env.unwrapped.possible_agents
}
)
pipeline = EnvToModulePipeline(
input_observation_space=obs_space,
input_action_space=act_space,
connectors=custom_connectors,
input_observation_space=getattr(
env, "single_observation_space", env.observation_space
),
input_action_space=getattr(env, "single_action_space", env.action_space),
)

if self.add_default_connectors_to_env_to_module_pipeline:
Expand Down Expand Up @@ -1048,12 +1062,26 @@ def build_module_to_env_connector(self, env):
f"pipeline)! Your function returned {val_}."
)

obs_space = getattr(env, "single_observation_space", env.observation_space)
if obs_space is None and self.is_multi_agent():
obs_space = gym.spaces.Dict(
{
aid: env.get_observation_space(aid)
for aid in env.unwrapped.possible_agents
}
)
act_space = getattr(env, "single_action_space", env.action_space)
if act_space is None and self.is_multi_agent():
act_space = gym.spaces.Dict(
{
aid: env.get_action_space(aid)
for aid in env.unwrapped.possible_agents
}
)
pipeline = ModuleToEnvPipeline(
input_observation_space=obs_space,
input_action_space=act_space,
connectors=custom_connectors,
input_observation_space=getattr(
env, "single_observation_space", env.observation_space
),
input_action_space=getattr(env, "single_action_space", env.action_space),
)

if self.add_default_connectors_to_module_to_env_pipeline:
Expand Down Expand Up @@ -4916,47 +4944,54 @@ def get_multi_agent_setup(

# Infer observation space.
if policy_spec.observation_space is None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Module's space is provided -> Use it as-is.
if spaces is not None and pid in spaces:
obs_space = spaces[pid][0]
elif env_obs_space is not None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if (
isinstance(env_unwrapped, MultiAgentEnv)
and hasattr(env_unwrapped, "_obs_space_in_preferred_format")
and env_unwrapped._obs_space_in_preferred_format
):
obs_space = None
mapping_fn = self.policy_mapping_fn
one_obs_space = next(iter(env_obs_space.values()))
# If all obs spaces are the same anyways, just use the first
# single-agent space.
if all(s == one_obs_space for s in env_obs_space.values()):
obs_space = one_obs_space
# Otherwise, we have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in env_unwrapped.get_agent_ids():
# Match: Assign spaces for this agentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
obs_space is not None
and env_obs_space[aid] != obs_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different observation spaces!"
)
obs_space = env_obs_space[aid]
# Otherwise, just use env's obs space as-is.
                # MultiAgentEnv -> Check whether agents have different spaces.
elif isinstance(env_unwrapped, MultiAgentEnv):
obs_space = None
mapping_fn = self.policy_mapping_fn
aids = list(
env_unwrapped.possible_agents
if hasattr(env_unwrapped, "possible_agents")
and env_unwrapped.possible_agents
else env_unwrapped.get_agent_ids()
)
if len(aids) == 0:
one_obs_space = env_unwrapped.observation_space
else:
obs_space = env_obs_space
one_obs_space = env_unwrapped.get_observation_space(aids[0])
# If all obs spaces are the same, just use the first space.
if all(
env_unwrapped.get_observation_space(aid) == one_obs_space
for aid in aids
):
obs_space = one_obs_space
# Need to reverse-map spaces (for the different agents) to certain
# policy IDs. We have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in aids:
# Match: Assign spaces for this agentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
obs_space is not None
and env_unwrapped.get_observation_space(aid)
!= obs_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different observation spaces!"
)
obs_space = env_unwrapped.get_observation_space(aid)
# Just use env's obs space as-is.
elif env_obs_space is not None:
obs_space = env_obs_space
# Space given directly in config.
elif self.observation_space:
obs_space = self.observation_space
Expand All @@ -4972,47 +5007,53 @@ def get_multi_agent_setup(

# Infer action space.
if policy_spec.action_space is None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Module's space is provided -> Use it as-is.
if spaces is not None and pid in spaces:
act_space = spaces[pid][1]
elif env_act_space is not None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if (
isinstance(env_unwrapped, MultiAgentEnv)
and hasattr(env_unwrapped, "_action_space_in_preferred_format")
and env_unwrapped._action_space_in_preferred_format
):
act_space = None
mapping_fn = self.policy_mapping_fn
one_act_space = next(iter(env_act_space.values()))
# If all action spaces are the same anyways, just use the first
# single-agent space.
if all(s == one_act_space for s in env_act_space.values()):
act_space = one_act_space
# Otherwise, we have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in env_unwrapped.get_agent_ids():
# Match: Assign spaces for this AgentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
act_space is not None
and env_act_space[aid] != act_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different action spaces!"
)
act_space = env_act_space[aid]
# Otherwise, just use env's action space as-is.
                # MultiAgentEnv -> Check whether agents have different spaces.
elif isinstance(env_unwrapped, MultiAgentEnv):
act_space = None
mapping_fn = self.policy_mapping_fn
aids = list(
env_unwrapped.possible_agents
if hasattr(env_unwrapped, "possible_agents")
and env_unwrapped.possible_agents
else env_unwrapped.get_agent_ids()
)
if len(aids) == 0:
one_act_space = env_unwrapped.action_space
else:
act_space = env_act_space
one_act_space = env_unwrapped.get_action_space(aids[0])
                    # If all action spaces are the same, just use the first space.
if all(
env_unwrapped.get_action_space(aid) == one_act_space
for aid in aids
):
act_space = one_act_space
# Need to reverse-map spaces (for the different agents) to certain
# policy IDs. We have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in aids:
# Match: Assign spaces for this AgentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
act_space is not None
and env_unwrapped.get_action_space(aid) != act_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different action spaces!"
)
act_space = env_unwrapped.get_action_space(aid)
# Just use env's action space as-is.
elif env_act_space is not None:
act_space = env_act_space
elif self.action_space:
act_space = self.action_space
else:
Expand Down
28 changes: 14 additions & 14 deletions rllib/algorithms/tests/test_algorithm_rl_module_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_e2e_load_simple_multi_rl_module(self):
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
# we need to provide the same config as the algorithm.
model_config_dict=config.model_config
Expand Down Expand Up @@ -115,8 +115,8 @@ def test_e2e_load_complex_multi_rl_module(self):
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
# we need to provide the same config as the algorithm.
model_config_dict=config.model_config
Expand All @@ -131,8 +131,8 @@ def test_e2e_load_complex_multi_rl_module(self):
# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
# to be able to use this module later.
model_config_dict=config.model_config | {"fcnet_hiddens": [64]},
Expand All @@ -146,8 +146,8 @@ def test_e2e_load_complex_multi_rl_module(self):
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
Expand Down Expand Up @@ -258,8 +258,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
for i in range(num_agents):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the
# algorithm to be able to use this module later.
model_config_dict=config.model_config
Expand All @@ -274,8 +274,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
# to be able to use this module later.
model_config_dict=config.model_config | {"fcnet_hiddens": [64]},
Expand All @@ -289,8 +289,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
Expand Down
Loading

0 comments on commit e07594e

Please sign in to comment.