From ff7d345dc2383df1129051d312fae607a6703805 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Mon, 20 Nov 2023 18:13:19 +0200 Subject: [PATCH 01/10] feat: general wrapper mostly complete Global state wrapper not yet working properly --- mava/systems/ff_ippo_rware.py | 8 ++-- mava/systems/ff_mappo_rware.py | 9 ++-- mava/systems/rec_ippo_rware.py | 11 +++-- mava/systems/rec_mappo_rware.py | 12 +++-- mava/types.py | 17 ++++--- mava/wrappers/jumanji.py | 84 +++++++++++++++------------------ 6 files changed, 70 insertions(+), 71 deletions(-) diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index bf8648f94..c4df33c87 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -30,7 +30,6 @@ from flax.linen.initializers import constant, orthogonal from jumanji.env import Environment from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator -from jumanji.types import Observation from jumanji.wrappers import AutoResetWrapper from omegaconf import DictConfig, OmegaConf from optax._src.base import OptState @@ -44,6 +43,7 @@ ExperimentOutput, LearnerFn, LearnerState, + Observation, OptStates, Params, PPOTransition, @@ -51,7 +51,7 @@ from mava.utils.jax import merge_leading_dims from mava.utils.logger_tools import get_sacred_exp from mava.utils.timing_utils import TimeIt -from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper +from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper class Actor(nn.Module): @@ -497,14 +497,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: # Create envs generator = RandomGenerator(**config["rware_scenario"]["task_config"]) env = jumanji.make(config["env_name"], generator=generator) - env = RwareMultiAgentWrapper(env) + env = RwareWrapper(env) # Add agent id to observation. if config["add_agent_id"]: env = AgentIDWrapper(env) env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = RwareMultiAgentWrapper(eval_env) + eval_env = RwareWrapper(eval_env) if config["add_agent_id"]: eval_env = AgentIDWrapper(eval_env) diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index cafb98dff..10d6e5159 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -29,7 +29,6 @@ from flax.linen.initializers import constant, orthogonal from jumanji.env import Environment from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator -from jumanji.types import Observation from jumanji.wrappers import AutoResetWrapper from omegaconf import DictConfig, OmegaConf from optax._src.base import OptState @@ -43,6 +42,7 @@ ExperimentOutput, LearnerFn, LearnerState, + Observation, OptStates, Params, PPOTransition, @@ -52,9 +52,10 @@ from mava.utils.timing_utils import TimeIt from mava.wrappers.jumanji import ( AgentIDWrapper, + GlobalStateWrapper, LogWrapper, ObservationGlobalState, - RwareMultiAgentWithGlobalStateWrapper, + RwareWrapper, ) @@ -506,14 +507,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: # Create envs generator = RandomGenerator(**config["rware_scenario"]["task_config"]) env = jumanji.make(config["env_name"], generator=generator) - env = RwareMultiAgentWithGlobalStateWrapper(env) + env = GlobalStateWrapper(RwareWrapper(env)) # Add agent id to observation. 
if config["add_agent_id"]: env = AgentIDWrapper(env=env, has_global_state=True) env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env) + eval_env = GlobalStateWrapper(RwareWrapper(env)) if config["add_agent_id"]: eval_env = AgentIDWrapper(env=eval_env, has_global_state=True) diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py index be1976966..9f5f50df8 100644 --- a/mava/systems/rec_ippo_rware.py +++ b/mava/systems/rec_ippo_rware.py @@ -48,10 +48,11 @@ RecActorApply, RecCriticApply, RNNLearnerState, + RnnObservation, ) from mava.utils.logger_tools import get_sacred_exp from mava.utils.timing_utils import TimeIt -from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper +from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper class ScannedRNN(nn.Module): @@ -91,7 +92,7 @@ class Actor(nn.Module): def __call__( self, policy_hidden_state: chex.Array, - observation_done: Tuple[chex.Array, chex.Array], + observation_done: RnnObservation, ) -> Tuple[chex.Array, distrax.Categorical]: """Forward pass.""" observation, done = observation_done @@ -130,7 +131,7 @@ class Critic(nn.Module): def __call__( self, critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: Tuple[chex.Array, chex.Array], + observation_done: RnnObservation, ) -> Tuple[chex.Array, chex.Array]: """Forward pass.""" observation, done = observation_done @@ -677,14 +678,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: # Create envs generator = RandomGenerator(**config["rware_scenario"]["task_config"]) env = jumanji.make(config["env_name"], generator=generator) - env = RwareMultiAgentWrapper(env) + env = RwareWrapper(env) # Add agent id to observation. 
if config["add_agent_id"]: env = AgentIDWrapper(env) env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = RwareMultiAgentWrapper(eval_env) + eval_env = RwareWrapper(eval_env) if config["add_agent_id"]: eval_env = AgentIDWrapper(eval_env) diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index a53390040..e92146af6 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -48,14 +48,16 @@ RecActorApply, RecCriticApply, RNNLearnerState, + RnnObservation, ) from mava.utils.logger_tools import get_sacred_exp from mava.utils.timing_utils import TimeIt from mava.wrappers.jumanji import ( AgentIDWrapper, + GlobalStateWrapper, LogWrapper, ObservationGlobalState, - RwareMultiAgentWithGlobalStateWrapper, + RwareWrapper, ) @@ -96,7 +98,7 @@ class Actor(nn.Module): def __call__( self, policy_hidden_state: chex.Array, - observation_done: Tuple[chex.Array, chex.Array], + observation_done: RnnObservation, ) -> Tuple[chex.Array, distrax.Categorical]: """Forward pass.""" observation, done = observation_done @@ -135,7 +137,7 @@ class Critic(nn.Module): def __call__( self, critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: Tuple[chex.Array, chex.Array], + observation_done: RnnObservation, ) -> Tuple[chex.Array, chex.Array]: """Forward pass.""" observation, done = observation_done @@ -686,14 +688,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: # Create envs generator = RandomGenerator(**config["rware_scenario"]["task_config"]) env = jumanji.make(config["env_name"], generator=generator) - env = RwareMultiAgentWithGlobalStateWrapper(env) + env = GlobalStateWrapper(RwareWrapper(env)) # Add agent id to observation. if config["add_agent_id"]: env = AgentIDWrapper(env=env, has_global_state=True) env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env) + eval_env = GlobalStateWrapper(RwareWrapper(env)) if config["add_agent_id"]: eval_env = AgentIDWrapper(env=eval_env, has_global_state=True) diff --git a/mava/types.py b/mava/types.py index d7db2f30e..9da2adf29 100644 --- a/mava/types.py +++ b/mava/types.py @@ -21,16 +21,21 @@ from optax._src.base import OptState from typing_extensions import NamedTuple, TypeAlias -from mava.wrappers.jumanji import LogEnvState - Action: TypeAlias = chex.Array Value: TypeAlias = chex.Array Done: TypeAlias = chex.Array HiddenState: TypeAlias = chex.Array -# Can't know the exact type of State or Timestep. +# Can't know the exact type of State. 
State: TypeAlias = Any -Observation: TypeAlias = Any + + +class Observation(NamedTuple): + agents_view: chex.Array + action_mask: chex.Array + step_count: chex.Numeric + + RnnObservation: TypeAlias = Tuple[Observation, Done] @@ -73,7 +78,7 @@ class LearnerState(NamedTuple): params: Params opt_states: OptStates key: chex.PRNGKey - env_state: LogEnvState + env_state: State timestep: TimeStep @@ -83,7 +88,7 @@ class RNNLearnerState(NamedTuple): params: Params opt_states: OptStates key: chex.PRNGKey - env_state: LogEnvState + env_state: State timestep: TimeStep dones: Done hstates: HiddenStates diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 847c0cde5..206096d79 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -18,10 +18,12 @@ import jax.numpy as jnp from jumanji import specs from jumanji.env import Environment -from jumanji.environments.routing.robot_warehouse import Observation, State +from jumanji.environments.routing.robot_warehouse import RobotWarehouse, State from jumanji.types import TimeStep from jumanji.wrappers import Wrapper +from mava.types import Observation + if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass else: @@ -169,28 +171,31 @@ def observation_spec(self) -> specs.Spec[Observation]: ) -class RwareMultiAgentWrapper(Wrapper): +class RwareWrapper(Wrapper): """Multi-agent wrapper for the Robotic Warehouse environment.""" - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - """Reset the environment. Updates the step count.""" - state, timestep = self._env.reset(key) - timestep.observation = Observation( + def __init__(self, env: RobotWarehouse): + super().__init__(env) + self._env: RobotWarehouse + + def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]: + observation = Observation( agents_view=timestep.observation.agents_view, action_mask=timestep.observation.action_mask, step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents), ) - return state, timestep + + return timestep.replace(observation=observation) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + """Reset the environment. Updates the step count.""" + state, timestep = self._env.reset(key) + return state, self.convert_timstep(timestep) def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """Step the environment. Updates the step count.""" state, timestep = self._env.step(state, action) - timestep.observation = Observation( - agents_view=timestep.observation.agents_view, - action_mask=timestep.observation.action_mask, - step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents), - ) - return state, timestep + return state, self.convert_timstep(timestep) def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `RobotWarehouse` environment.""" @@ -204,67 +209,52 @@ def observation_spec(self) -> specs.Spec[Observation]: return self._env.observation_spec().replace(step_count=step_count) -class RwareMultiAgentWithGlobalStateWrapper(Wrapper): - """Multi-agent wrapper for the Robotic Warehouse environment. +class GlobalStateWrapper(Wrapper): + """Wrapper for adding global state to an environment that follows the mava API. The wrapper includes a global environment state to be used by the centralised critic. Note here that since robotic warehouse does not have a global state, we create one by concatenating the observations of all agents. 
""" - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - """Reset the environment. Updates the step count.""" - state, timestep = self._env.reset(key) + def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]: global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) global_state = jnp.tile(global_state, (self._env.num_agents, 1)) - timestep.observation = ObservationGlobalState( + + observation = ObservationGlobalState( + global_state=global_state, agents_view=timestep.observation.agents_view, action_mask=timestep.observation.action_mask, - global_state=global_state, - step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents), + step_count=timestep.observation.step_count, ) - return state, timestep + + return timestep.replace(observation=observation) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + """Reset the environment. Updates the step count.""" + state, timestep = self._env.reset(key) + return state, self.convert_timstep(timestep) def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """Step the environment. Updates the step count.""" state, timestep = self._env.step(state, action) - global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) - global_state = jnp.tile(global_state, (self._env.num_agents, 1)) - timestep.observation = ObservationGlobalState( - agents_view=timestep.observation.agents_view, - action_mask=timestep.observation.action_mask, - global_state=global_state, - step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents), - ) - return state, timestep + return state, self.convert_timstep(timestep) def observation_spec(self) -> specs.Spec[ObservationGlobalState]: """Specification of the observation of the `RobotWarehouse` environment.""" - agents_view = specs.Array( - (self._env.num_agents, self._env.num_obs_features), jnp.int32, "agents_view" - ) - action_mask = specs.BoundedArray( - (self._env.num_agents, 5), bool, False, True, "action_mask" - ) + obs_spec = self._env.observation_spec() global_state = specs.Array( (self._env.num_agents, self._env.num_agents * self._env.num_obs_features), jnp.int32, "global_state", ) - step_count = specs.BoundedArray( - (self._env.num_agents,), - jnp.int32, - [0] * self._env.num_agents, - [self._env.time_limit] * self._env.num_agents, - "step_count", - ) return specs.Spec( ObservationGlobalState, "ObservationSpec", - agents_view=agents_view, - action_mask=action_mask, + agents_view=obs_spec.agents_view, + action_mask=obs_spec.action_mask, global_state=global_state, - step_count=step_count, + step_count=obs_spec.step_count, ) From b1f33fa1ab613c9d903161724c1bdf333e2487f7 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 21 Nov 2023 10:45:05 +0200 Subject: [PATCH 02/10] fix: centralized critic eval bug with new wrapper --- mava/systems/ff_mappo_rware.py | 2 +- mava/systems/rec_mappo_rware.py | 2 +- mava/wrappers/jumanji.py | 21 +++++++++++++-------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index 10d6e5159..e6dc4a4de 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -514,7 +514,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = GlobalStateWrapper(RwareWrapper(env)) + eval_env = GlobalStateWrapper(RwareWrapper(eval_env)) 
if config["add_agent_id"]: eval_env = AgentIDWrapper(env=eval_env, has_global_state=True) diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index e92146af6..b69223375 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -695,7 +695,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: env = AutoResetWrapper(env) env = LogWrapper(env) eval_env = jumanji.make(config["env_name"], generator=generator) - eval_env = GlobalStateWrapper(RwareWrapper(env)) + eval_env = GlobalStateWrapper(RwareWrapper(eval_env)) if config["add_agent_id"]: eval_env = AgentIDWrapper(env=eval_env, has_global_state=True) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 206096d79..ee14153a1 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -178,24 +178,29 @@ def __init__(self, env: RobotWarehouse): super().__init__(env) self._env: RobotWarehouse - def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + n_agents = self._env.num_agents observation = Observation( agents_view=timestep.observation.agents_view, action_mask=timestep.observation.action_mask, - step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents), + step_count=jnp.repeat(timestep.observation.step_count, n_agents), ) + # todo (weim): get this working so that ppo always expects (n_agents,) size for reward and discount + # reward = jnp.repeat(timestep.reward, n_agents) + # discount = jnp.repeat(timestep.discount, n_agents) + # return timestep.replace(observation=observation, reward=reward, discount=discount) return timestep.replace(observation=observation) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment. Updates the step count.""" state, timestep = self._env.reset(key) - return state, self.convert_timstep(timestep) + return state, self.modify_timestep(timestep) def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """Step the environment. Updates the step count.""" state, timestep = self._env.step(state, action) - return state, self.convert_timstep(timestep) + return state, self.modify_timestep(timestep) def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `RobotWarehouse` environment.""" @@ -213,11 +218,11 @@ class GlobalStateWrapper(Wrapper): """Wrapper for adding global state to an environment that follows the mava API. The wrapper includes a global environment state to be used by the centralised critic. - Note here that since robotic warehouse does not have a global state, we create one + Note here that since most environments do not have a global state, we create one by concatenating the observations of all agents. """ - def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) global_state = jnp.tile(global_state, (self._env.num_agents, 1)) @@ -233,12 +238,12 @@ def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]: def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment. 
Updates the step count.""" state, timestep = self._env.reset(key) - return state, self.convert_timstep(timestep) + return state, self.modify_timestep(timestep) def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """Step the environment. Updates the step count.""" state, timestep = self._env.step(state, action) - return state, self.convert_timstep(timestep) + return state, self.modify_timestep(timestep) def observation_spec(self) -> specs.Spec[ObservationGlobalState]: """Specification of the observation of the `RobotWarehouse` environment.""" From b520cf391ba798dc79df3306170c9d1624c4ad7f Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 21 Nov 2023 20:34:11 +0200 Subject: [PATCH 03/10] feat: repeat reward and discount in wrapper --- mava/evaluator.py | 4 +++- mava/systems/ff_ippo_rware.py | 7 ++----- mava/wrappers/jumanji.py | 9 +++------ 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/mava/evaluator.py b/mava/evaluator.py index ce6c7c221..5c740304b 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -105,7 +105,9 @@ def evaluator_fn(trained_params: FrozenDict, rng: chex.PRNGKey) -> ExperimentOut # Add dimension to pmap over. step_rngs = jnp.stack(step_rngs).reshape(eval_batch, -1) - eval_state = EvalState(step_rngs, env_states, timesteps, 0, 0.0) + eval_state = EvalState( + step_rngs, env_states, timesteps, 0, jnp.zeros_like(timesteps.reward) + ) eval_metrics = jax.vmap( eval_one_episode, in_axes=(None, EvalState(0, 0, 0, None, None)), diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index c4df33c87..09686a616 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -151,17 +151,14 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS - done, reward = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config["num_agents"]).reshape(config["num_envs"], -1), - (timestep.last(), timestep.reward), - ) info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, } + done = 1 - timestep.discount transition = PPOTransition( - done, action, value, reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) learner_state = LearnerState(params, opt_states, rng, env_state, timestep) return learner_state, transition diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index ee14153a1..18ab1482a 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -185,12 +185,9 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: action_mask=timestep.observation.action_mask, step_count=jnp.repeat(timestep.observation.step_count, n_agents), ) - # todo (weim): get this working so that ppo always expects (n_agents,) size for reward and discount - # reward = jnp.repeat(timestep.reward, n_agents) - # discount = jnp.repeat(timestep.discount, n_agents) - # return timestep.replace(observation=observation, reward=reward, discount=discount) - - return timestep.replace(observation=observation) + reward = jnp.repeat(timestep.reward, n_agents) + discount = jnp.repeat(timestep.discount, n_agents) + return timestep.replace(observation=observation, reward=reward, discount=discount) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment. 
Updates the step count.""" From 4236d33448fb9d640080d230c62a50e3a1a0d88f Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Wed, 22 Nov 2023 19:32:44 +0100 Subject: [PATCH 04/10] feat: edit types and wrapper for other systems --- mava/configs/arch/anakin.yaml | 2 +- mava/evaluator.py | 2 +- mava/systems/ff_ippo_rware.py | 8 +++++- mava/systems/ff_mappo_rware.py | 8 +++--- mava/systems/rec_ippo_rware.py | 6 ++--- mava/systems/rec_mappo_rware.py | 6 ++--- mava/types.py | 41 +++++++++++++++++++++++++++--- mava/wrappers/jumanji.py | 45 +++++---------------------------- 8 files changed, 62 insertions(+), 56 deletions(-) diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml index 86e75898b..ef9aa4314 100644 --- a/mava/configs/arch/anakin.yaml +++ b/mava/configs/arch/anakin.yaml @@ -1,7 +1,7 @@ # --- Anakin config --- # --- Training --- -num_envs: 16 # Number of vectorised environments per device. +num_envs: 256 # Number of vectorised environments per device. # --- Evaluation --- evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select diff --git a/mava/evaluator.py b/mava/evaluator.py index 7c2fa04c7..ac77c7723 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -233,7 +233,7 @@ def evaluator_fn( dones=dones, hstate=init_hstate, step_count_=0, - return_=0.0, + return_=jnp.zeros_like(timesteps.reward), ) eval_metrics = jax.vmap( diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index 0e16c3916..14e0b7529 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -157,7 +157,13 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra "episode_length": env_state.episode_length_info, } - done = 1 - timestep.discount + # done = 1 - timestep.discount + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( + config["arch"]["num_envs"], -1 + ), + timestep.last(), + ) transition = PPOTransition( done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index f5021bd5a..00b824d2b 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -157,11 +157,11 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS - done, reward = jax.tree_util.tree_map( + done = jax.tree_util.tree_map( lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( config["arch"]["num_envs"], -1 ), - (timestep.last(), timestep.reward), + timestep.last(), ) info = { "episode_return": env_state.episode_return_info, @@ -169,7 +169,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra } transition = PPOTransition( - done, action, value, reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) learner_state = LearnerState(params, opt_states, rng, env_state, timestep) return learner_state, transition @@ -526,7 +526,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None: env = LogWrapper(env) eval_env = jumanji.make(config["env"]["env_name"], generator=generator) eval_env = GlobalStateWrapper(RwareWrapper(eval_env)) - if config["add_agent_id"]: + if config["system"]["add_agent_id"]: eval_env = AgentIDWrapper(env=eval_env, has_global_state=True) # PRNG keys. 
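For orientation, the wrapper stack that these centralised-critic systems build around the raw Jumanji environment, as reflected in the hunks above, is sketched below. This is an illustrative, self-contained example rather than part of the patch: the "RobotWarehouse-v0" registry name stands in for config["env"]["env_name"], the generator arguments are made-up values standing in for rware_scenario.task_config, and AgentIDWrapper is applied unconditionally here even though the systems gate it behind the add_agent_id flag.

import jumanji
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
from jumanji.wrappers import AutoResetWrapper

from mava.wrappers.jumanji import (
    AgentIDWrapper,
    GlobalStateWrapper,
    LogWrapper,
    RwareWrapper,
)

# Illustrative task settings; the real values come from rware_scenario.task_config.
generator = RandomGenerator(
    shelf_rows=1,
    shelf_columns=3,
    column_height=8,
    num_agents=2,
    sensor_range=1,
    request_queue_size=2,
)
env = jumanji.make("RobotWarehouse-v0", generator=generator)

env = GlobalStateWrapper(RwareWrapper(env))           # per-agent obs plus concatenated global state
env = AgentIDWrapper(env=env, has_global_state=True)  # one-hot agent IDs (gated by add_agent_id)
env = AutoResetWrapper(env)                           # reset automatically at episode end
env = LogWrapper(env)                                 # track episode return and length

# The evaluation environment is assembled the same way, but without
# AutoResetWrapper and LogWrapper.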
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py index 4a3ac847c..f151d2cba 100644 --- a/mava/systems/rec_ippo_rware.py +++ b/mava/systems/rec_ippo_rware.py @@ -229,11 +229,11 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done, reward = jax.tree_util.tree_map( + done = jax.tree_util.tree_map( lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( config["arch"]["num_envs"], -1 ), - (timestep.last(), timestep.reward), + timestep.last(), ) info = { "episode_return": env_state.episode_return_info, @@ -241,7 +241,7 @@ def _env_step( } transition = PPOTransition( - done, action, value, reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) learner_state = RNNLearnerState( diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index 05699c000..afd4d4ac3 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -229,11 +229,11 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done, reward = jax.tree_util.tree_map( + done = jax.tree_util.tree_map( lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( config["arch"]["num_envs"], -1 ), - (timestep.last(), timestep.reward), + timestep.last(), ) info = { "episode_return": env_state.episode_return_info, @@ -241,7 +241,7 @@ def _env_step( } transition = PPOTransition( - done, action, value, reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) learner_state = RNNLearnerState( diff --git a/mava/types.py b/mava/types.py index b888a3f16..7455e9b2f 100644 --- a/mava/types.py +++ b/mava/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar import chex from distrax import Distribution @@ -21,13 +21,16 @@ from optax._src.base import OptState from typing_extensions import NamedTuple, TypeAlias -from mava.wrappers.jumanji import LogEnvState +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from flax.struct import dataclass + Action: TypeAlias = chex.Array Value: TypeAlias = chex.Array Done: TypeAlias = chex.Array HiddenState: TypeAlias = chex.Array - # Can't know the exact type of State. State: TypeAlias = Any @@ -38,7 +41,37 @@ class Observation(NamedTuple): step_count: chex.Numeric -RnnObservation: TypeAlias = Tuple[Observation, Done] +class ObservationGlobalState(NamedTuple): + """The observation that the agent sees. + agents_view: the agents' view of other agents and shelves within their + sensor range. The number of features in the observation array + depends on the sensor range of the agent. + action_mask: boolean array specifying, for each agent, which action + (up, right, down, left) is legal. + global_state: the global state of the environment, which is the + concatenation of the agents' views. + step_count: the number of steps elapsed since the beginning of the episode. 
+ """ + + agents_view: chex.Array # (num_agents, num_obs_features) + action_mask: chex.Array # (num_agents, num_actions) + global_state: chex.Array # (num_agents, num_agents * num_obs_features, ) + step_count: chex.Array # (num_agents, ) + + +@dataclass +class LogEnvState: + """State of the `LogWrapper`.""" + + env_state: State + episode_returns: chex.Numeric + episode_lengths: chex.Numeric + # Information about the episode return and length for logging purposes. + episode_return_info: chex.Numeric + episode_length_info: chex.Numeric + + +RnnObservation: TypeAlias = Tuple[ObservationGlobalState, Done] class PPOTransition(NamedTuple): diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 0b9e87475..aec4c0ed0 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, NamedTuple, Tuple, Union +from typing import Tuple, Union import chex import jax.numpy as jnp @@ -22,42 +22,7 @@ from jumanji.types import TimeStep from jumanji.wrappers import Wrapper -from mava.types import Observation - -if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 - from dataclasses import dataclass -else: - from flax.struct import dataclass - - -class ObservationGlobalState(NamedTuple): - """The observation that the agent sees. - agents_view: the agents' view of other agents and shelves within their - sensor range. The number of features in the observation array - depends on the sensor range of the agent. - action_mask: boolean array specifying, for each agent, which action - (up, right, down, left) is legal. - global_state: the global state of the environment, which is the - concatenation of the agents' views. - step_count: the number of steps elapsed since the beginning of the episode. - """ - - agents_view: chex.Array # (num_agents, num_obs_features) - action_mask: chex.Array # (num_agents, num_actions) - global_state: chex.Array # (num_agents, num_agents * num_obs_features, ) - step_count: chex.Array # (num_agents, ) - - -@dataclass -class LogEnvState: - """State of the `LogWrapper`.""" - - env_state: State - episode_returns: chex.Numeric - episode_lengths: chex.Numeric - # Information about the episode return and length for logging purposes. - episode_return_info: chex.Numeric - episode_length_info: chex.Numeric +from mava.types import LogEnvState, Observation, ObservationGlobalState class LogWrapper(Wrapper): @@ -186,8 +151,10 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: step_count=jnp.repeat(timestep.observation.step_count, n_agents), ) reward = jnp.repeat(timestep.reward, n_agents) - discount = jnp.repeat(timestep.discount, n_agents) - return timestep.replace(observation=observation, reward=reward, discount=discount) + # discount = jnp.repeat(timestep.discount, n_agents) + # -> we won't need this if we'll use the timestep.last() for the 'done' + # variable during training. + return timestep.replace(observation=observation, reward=reward) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment. 
Updates the step count.""" From 3117226e1ddebefd82a55acc8100ed83d41f9376 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Thu, 23 Nov 2023 15:43:54 +0100 Subject: [PATCH 05/10] feat: edit rec systems --- docs/jumanji_rware_comparison.md | 2 +- mava/systems/ff_ippo_rware.py | 8 +------- mava/systems/ff_mappo_rware.py | 9 ++------- mava/systems/rec_ippo_rware.py | 11 +++-------- mava/systems/rec_mappo_rware.py | 19 +++++++------------ mava/types.py | 3 ++- mava/wrappers/jumanji.py | 9 ++++----- 7 files changed, 20 insertions(+), 41 deletions(-) diff --git a/docs/jumanji_rware_comparison.md b/docs/jumanji_rware_comparison.md index 0613c9ad3..3d041ef12 100644 --- a/docs/jumanji_rware_comparison.md +++ b/docs/jumanji_rware_comparison.md @@ -66,7 +66,7 @@ Please see below for Mava's recurrent and feedforward implementations of IPPO an Mava rec mappo small 4ag
-
Mava recurrent IPPO performance on the tiny-2ag, tiny-4ag and small-4ag RWARE tasks.
+Mava recurrent MAPPO performance on the tiny-2ag, tiny-4ag and small-4ag RWARE tasks.

diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index 14e0b7529..6902897dc 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -152,18 +152,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS + done = 1 - timestep.discount info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, } - # done = 1 - timestep.discount - done = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( - config["arch"]["num_envs"], -1 - ), - timestep.last(), - ) transition = PPOTransition( done, action, value, timestep.reward, log_prob, last_timestep.observation, info ) diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index 00b824d2b..131ee12af 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -45,6 +45,7 @@ LearnerFn, LearnerState, Observation, + ObservationGlobalState, OptStates, Params, PPOTransition, @@ -55,7 +56,6 @@ AgentIDWrapper, GlobalStateWrapper, LogWrapper, - ObservationGlobalState, RwareWrapper, ) @@ -157,12 +157,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS - done = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( - config["arch"]["num_envs"], -1 - ), - timestep.last(), - ) + done = 1 - timestep.discount info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py index f151d2cba..9dc280e1f 100644 --- a/mava/systems/rec_ippo_rware.py +++ b/mava/systems/rec_ippo_rware.py @@ -229,12 +229,7 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( - config["arch"]["num_envs"], -1 - ), - timestep.last(), - ) + done = 1 - timestep.discount info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, @@ -586,7 +581,7 @@ def learner_setup( init_obs, ) init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) - init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool) + init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float) init_x = (init_obs, init_done) # Initialise hidden states. 
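The done = 1 - timestep.discount changes above lean on the wrapper (modified later in this same commit) repeating reward and discount once per agent, so the per-agent done flags can be read straight off the discount. Below is a minimal shape sketch with made-up values for two agents; it is not part of the diff.

import jax.numpy as jnp

num_agents = 2

# What the RwareWrapper's modify_timestep produces from Jumanji's scalar signals:
team_reward = jnp.float32(1.0)
team_discount = jnp.float32(0.0)                  # 0.0 on the terminal step, 1.0 otherwise
reward = jnp.repeat(team_reward, num_agents)      # shape (num_agents,)
discount = jnp.repeat(team_discount, num_agents)  # shape (num_agents,)

# Per-agent done flags, as read by the PPO systems in this commit:
done = 1 - discount                               # 1.0 for every agent when the episode ended
assert done.shape == (num_agents,)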
@@ -663,7 +658,7 @@ def learner_setup( config["arch"]["num_envs"], config["system"]["num_agents"], ), - dtype=bool, + dtype=float, ) hstates = HiddenStates(policy_hstates, critic_hstates) params = Params(actor_params, critic_params) diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index afd4d4ac3..fae0c0190 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -44,20 +44,20 @@ ExperimentOutput, HiddenStates, LearnerFn, + ObservationGlobalState, OptStates, Params, PPOTransition, RecActorApply, RecCriticApply, + RnnGlobalObservation, RNNLearnerState, - RnnObservation, ) from mava.utils.logger_tools import get_sacred_exp from mava.wrappers.jumanji import ( AgentIDWrapper, GlobalStateWrapper, LogWrapper, - ObservationGlobalState, RwareWrapper, ) @@ -99,7 +99,7 @@ class Actor(nn.Module): def __call__( self, policy_hidden_state: chex.Array, - observation_done: RnnObservation, + observation_done: RnnGlobalObservation, ) -> Tuple[chex.Array, distrax.Categorical]: """Forward pass.""" observation, done = observation_done @@ -138,7 +138,7 @@ class Critic(nn.Module): def __call__( self, critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: RnnObservation, + observation_done: RnnGlobalObservation, ) -> Tuple[chex.Array, chex.Array]: """Forward pass.""" observation, done = observation_done @@ -229,12 +229,7 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done = jax.tree_util.tree_map( - lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( - config["arch"]["num_envs"], -1 - ), - timestep.last(), - ) + done = 1 - timestep.discount info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, @@ -590,7 +585,7 @@ def learner_setup( init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) # Select only a single agent - init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool) + init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float) init_obs_single = ObservationGlobalState( agents_view=init_obs.agents_view[:, :, 0, :], action_mask=init_obs.action_mask[:, :, 0, :], @@ -679,7 +674,7 @@ def learner_setup( config["arch"]["num_envs"], config["system"]["num_agents"], ), - dtype=bool, + dtype=float, ) hstates = HiddenStates(policy_hstates, critic_hstates) params = Params(actor_params, critic_params) diff --git a/mava/types.py b/mava/types.py index 7455e9b2f..ecb523631 100644 --- a/mava/types.py +++ b/mava/types.py @@ -71,7 +71,8 @@ class LogEnvState: episode_length_info: chex.Numeric -RnnObservation: TypeAlias = Tuple[ObservationGlobalState, Done] +RnnObservation: TypeAlias = Tuple[Observation, Done] +RnnGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done] class PPOTransition(NamedTuple): diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index aec4c0ed0..a0cb23903 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -150,11 +150,10 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: action_mask=timestep.observation.action_mask, step_count=jnp.repeat(timestep.observation.step_count, n_agents), ) - reward = jnp.repeat(timestep.reward, n_agents) - # discount = jnp.repeat(timestep.discount, n_agents) - # -> we won't need this if we'll use the timestep.last() for the 'done' - # variable during training. 
- return timestep.replace(observation=observation, reward=reward) + shared_reward = jnp.sum(timestep.reward) + reward = jnp.repeat(shared_reward, n_agents) + discount = jnp.repeat(timestep.discount, n_agents) + return timestep.replace(observation=observation, reward=reward, discount=discount) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment. Updates the step count.""" From a6df77e0785b38067b85f26e6097865de3152c53 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Thu, 23 Nov 2023 17:38:23 +0100 Subject: [PATCH 06/10] revert: revert the changes made on 'done' variable --- mava/systems/ff_ippo_rware.py | 7 ++++++- mava/systems/ff_mappo_rware.py | 7 ++++++- mava/systems/rec_ippo_rware.py | 11 ++++++++--- mava/systems/rec_mappo_rware.py | 11 ++++++++--- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index 6902897dc..eae1bee31 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -152,7 +152,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS - done = 1 - timestep.discount + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( + config["arch"]["num_envs"], -1 + ), + timestep.last(), + ) info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index 131ee12af..29e92f5e0 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -157,7 +157,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS - done = 1 - timestep.discount + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( + config["arch"]["num_envs"], -1 + ), + timestep.last(), + ) info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py index 9dc280e1f..f151d2cba 100644 --- a/mava/systems/rec_ippo_rware.py +++ b/mava/systems/rec_ippo_rware.py @@ -229,7 +229,12 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done = 1 - timestep.discount + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( + config["arch"]["num_envs"], -1 + ), + timestep.last(), + ) info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, @@ -581,7 +586,7 @@ def learner_setup( init_obs, ) init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) - init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float) + init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool) init_x = (init_obs, init_done) # Initialise hidden states. 
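For context on why init_done is packed together with the observations in init_x above: the recurrent actor and critic consume (observation, done) tuples, and a scanned RNN cell typically uses the done flags to re-initialise its carry at episode boundaries. The actual ScannedRNN used by these systems is not shown in this diff, so the sketch below is only the general pattern, with a hypothetical helper name.

import jax.numpy as jnp

def reset_hidden_state(hidden_state, fresh_hidden_state, done):
    """Reset the RNN carry for every environment whose episode just ended.

    hidden_state:       (batch, hidden_dim) current carry
    fresh_hidden_state: (batch, hidden_dim) newly initialised carry
    done:               (batch,) flags taken from the (observation, done) input
    """
    return jnp.where(done[:, None], fresh_hidden_state, hidden_state)

# Example: only the second environment's carry is re-initialised.
carry = jnp.ones((2, 4))
fresh = jnp.zeros((2, 4))
done = jnp.array([False, True])
print(reset_hidden_state(carry, fresh, done))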
@@ -658,7 +663,7 @@ def learner_setup( config["arch"]["num_envs"], config["system"]["num_agents"], ), - dtype=float, + dtype=bool, ) hstates = HiddenStates(policy_hstates, critic_hstates) params = Params(actor_params, critic_params) diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index fae0c0190..7cfa727e6 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -229,7 +229,12 @@ def _env_step( env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # log episode return and length - done = 1 - timestep.discount + done = jax.tree_util.tree_map( + lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape( + config["arch"]["num_envs"], -1 + ), + timestep.last(), + ) info = { "episode_return": env_state.episode_return_info, "episode_length": env_state.episode_length_info, @@ -585,7 +590,7 @@ def learner_setup( init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) # Select only a single agent - init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float) + init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool) init_obs_single = ObservationGlobalState( agents_view=init_obs.agents_view[:, :, 0, :], action_mask=init_obs.action_mask[:, :, 0, :], @@ -674,7 +679,7 @@ def learner_setup( config["arch"]["num_envs"], config["system"]["num_agents"], ), - dtype=float, + dtype=bool, ) hstates = HiddenStates(policy_hstates, critic_hstates) params = Params(actor_params, critic_params) From 7547931e0bf819ef99377620403cb52c7467d2bc Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Fri, 24 Nov 2023 11:06:38 +0100 Subject: [PATCH 07/10] revert: revert some changes + doc changing based on review --- mava/configs/system/ff_ippo.yaml | 2 +- mava/types.py | 20 +++++++++++--------- mava/wrappers/jumanji.py | 3 +-- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mava/configs/system/ff_ippo.yaml b/mava/configs/system/ff_ippo.yaml index 9b62c7cfe..ab91c66b2 100644 --- a/mava/configs/system/ff_ippo.yaml +++ b/mava/configs/system/ff_ippo.yaml @@ -1,6 +1,6 @@ # --- Defaults FF-IPPO --- -num_updates: 10000 # Number of updates +num_updates: 1000 # Number of updates seed: 42 # --- Agent observations --- diff --git a/mava/types.py b/mava/types.py index ecb523631..ffe8ecb51 100644 --- a/mava/types.py +++ b/mava/types.py @@ -36,21 +36,23 @@ class Observation(NamedTuple): + """The observation that the agent sees. + agents_view: the agents' view of other agents and shelves within their + sensor range. The number of features in the observation array + depends on the sensor range of the agent. + action_mask: boolean array specifying, for each agent, which action is legal. + step_count: the number of steps elapsed since the beginning of the episode. + """ + agents_view: chex.Array action_mask: chex.Array step_count: chex.Numeric class ObservationGlobalState(NamedTuple): - """The observation that the agent sees. - agents_view: the agents' view of other agents and shelves within their - sensor range. The number of features in the observation array - depends on the sensor range of the agent. - action_mask: boolean array specifying, for each agent, which action - (up, right, down, left) is legal. - global_state: the global state of the environment, which is the - concatenation of the agents' views. - step_count: the number of steps elapsed since the beginning of the episode. + """The observation seen by agents in centralized systems. 
+ Extends `Observation` by adding a `global_state` attribute for centralized training. + global_state: The global state of the environment, often a concatenation of agents' views. """ agents_view: chex.Array # (num_agents, num_obs_features) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index a0cb23903..fc777e3b1 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -150,8 +150,7 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: action_mask=timestep.observation.action_mask, step_count=jnp.repeat(timestep.observation.step_count, n_agents), ) - shared_reward = jnp.sum(timestep.reward) - reward = jnp.repeat(shared_reward, n_agents) + reward = jnp.repeat(timestep.reward, n_agents) discount = jnp.repeat(timestep.discount, n_agents) return timestep.replace(observation=observation, reward=reward, discount=discount) From 7e500aaa2c657ce818689c589020f39284391da4 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Mon, 27 Nov 2023 15:37:36 +0100 Subject: [PATCH 08/10] fix: small edits --- mava/wrappers/jumanji.py | 86 ++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index fc777e3b1..48bac789b 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -60,6 +60,46 @@ def step( return state, timestep +class RwareWrapper(Wrapper): + """Multi-agent wrapper for the Robotic Warehouse environment.""" + + def __init__(self, env: RobotWarehouse): + super().__init__(env) + self._env: RobotWarehouse + + def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + n_agents = self._env.num_agents + observation = Observation( + agents_view=timestep.observation.agents_view, + action_mask=timestep.observation.action_mask, + step_count=jnp.repeat(timestep.observation.step_count, n_agents), + ) + reward = jnp.repeat(timestep.reward, n_agents) + discount = jnp.repeat(timestep.discount, n_agents) + return timestep.replace(observation=observation, reward=reward, discount=discount) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + """Reset the environment. Updates the step count.""" + state, timestep = self._env.reset(key) + return state, self.modify_timestep(timestep) + + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + """Step the environment. 
Updates the step count.""" + state, timestep = self._env.step(state, action) + return state, self.modify_timestep(timestep) + + def observation_spec(self) -> specs.Spec[Observation]: + """Specification of the observation of the `RobotWarehouse` environment.""" + step_count = specs.BoundedArray( + (self._env.num_agents,), + jnp.int32, + [0] * self._env.num_agents, + [self._env.time_limit] * self._env.num_agents, + "step_count", + ) + return self._env.observation_spec().replace(step_count=step_count) + + class AgentIDWrapper(Wrapper): """Add onehot agent IDs to observation.""" @@ -112,7 +152,9 @@ def step( return state, timestep - def observation_spec(self) -> specs.Spec[Observation]: + def observation_spec( + self, + ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]: """Specification of the observation of the `RobotWarehouse` environment.""" agents_view = specs.Array( (self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view" @@ -136,46 +178,6 @@ def observation_spec(self) -> specs.Spec[Observation]: ) -class RwareWrapper(Wrapper): - """Multi-agent wrapper for the Robotic Warehouse environment.""" - - def __init__(self, env: RobotWarehouse): - super().__init__(env) - self._env: RobotWarehouse - - def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: - n_agents = self._env.num_agents - observation = Observation( - agents_view=timestep.observation.agents_view, - action_mask=timestep.observation.action_mask, - step_count=jnp.repeat(timestep.observation.step_count, n_agents), - ) - reward = jnp.repeat(timestep.reward, n_agents) - discount = jnp.repeat(timestep.discount, n_agents) - return timestep.replace(observation=observation, reward=reward, discount=discount) - - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - """Reset the environment. Updates the step count.""" - state, timestep = self._env.reset(key) - return state, self.modify_timestep(timestep) - - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: - """Step the environment. Updates the step count.""" - state, timestep = self._env.step(state, action) - return state, self.modify_timestep(timestep) - - def observation_spec(self) -> specs.Spec[Observation]: - """Specification of the observation of the `RobotWarehouse` environment.""" - step_count = specs.BoundedArray( - (self._env.num_agents,), - jnp.int32, - [0] * self._env.num_agents, - [self._env.time_limit] * self._env.num_agents, - "step_count", - ) - return self._env.observation_spec().replace(step_count=step_count) - - class GlobalStateWrapper(Wrapper): """Wrapper for adding global state to an environment that follows the mava API. @@ -184,7 +186,7 @@ class GlobalStateWrapper(Wrapper): by concatenating the observations of all agents. 
""" - def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]: global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) global_state = jnp.tile(global_state, (self._env.num_agents, 1)) From e6b51eb86752ea14afb9a3854917fe786ed54ca9 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Wed, 29 Nov 2023 12:50:29 +0100 Subject: [PATCH 09/10] docs: small edits based on review --- mava/systems/rec_ippo_rware.py | 6 +++--- mava/systems/rec_mappo_rware.py | 6 +++--- mava/types.py | 26 +++++++++++++------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py index 224d61337..9678b3715 100644 --- a/mava/systems/rec_ippo_rware.py +++ b/mava/systems/rec_ippo_rware.py @@ -48,7 +48,7 @@ RecActorApply, RecCriticApply, RNNLearnerState, - RnnObservation, + RNNObservation, ) from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper @@ -90,7 +90,7 @@ class Actor(nn.Module): def __call__( self, policy_hidden_state: chex.Array, - observation_done: RnnObservation, + observation_done: RNNObservation, ) -> Tuple[chex.Array, distrax.Categorical]: """Forward pass.""" observation, done = observation_done @@ -129,7 +129,7 @@ class Critic(nn.Module): def __call__( self, critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: RnnObservation, + observation_done: RNNObservation, ) -> Tuple[chex.Array, chex.Array]: """Forward pass.""" observation, done = observation_done diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py index abcac0486..9765bc1b8 100644 --- a/mava/systems/rec_mappo_rware.py +++ b/mava/systems/rec_mappo_rware.py @@ -48,7 +48,7 @@ PPOTransition, RecActorApply, RecCriticApply, - RnnGlobalObservation, + RNNGlobalObservation, RNNLearnerState, ) from mava.wrappers.jumanji import ( @@ -96,7 +96,7 @@ class Actor(nn.Module): def __call__( self, policy_hidden_state: chex.Array, - observation_done: RnnGlobalObservation, + observation_done: RNNGlobalObservation, ) -> Tuple[chex.Array, distrax.Categorical]: """Forward pass.""" observation, done = observation_done @@ -135,7 +135,7 @@ class Critic(nn.Module): def __call__( self, critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: RnnGlobalObservation, + observation_done: RNNGlobalObservation, ) -> Tuple[chex.Array, chex.Array]: """Forward pass.""" observation, done = observation_done diff --git a/mava/types.py b/mava/types.py index ffe8ecb51..1d47e0958 100644 --- a/mava/types.py +++ b/mava/types.py @@ -37,27 +37,27 @@ class Observation(NamedTuple): """The observation that the agent sees. - agents_view: the agents' view of other agents and shelves within their - sensor range. The number of features in the observation array - depends on the sensor range of the agent. + agents_view: the agents' view of other agents and items within their + field of view (fov). The number of features in the observation array + depends on the number of elemnts that can be seen in the fov of the agent. action_mask: boolean array specifying, for each agent, which action is legal. step_count: the number of steps elapsed since the beginning of the episode. 
""" - agents_view: chex.Array - action_mask: chex.Array - step_count: chex.Numeric + agents_view: chex.Array # (num_agents, num_obs_features) + action_mask: chex.Array # (num_agents, num_actions) + step_count: chex.Array # (num_agents, ) class ObservationGlobalState(NamedTuple): - """The observation seen by agents in centralized systems. - Extends `Observation` by adding a `global_state` attribute for centralized training. + """The observation seen by agents in centralised systems. + Extends `Observation` by adding a `global_state` attribute for centralised training. global_state: The global state of the environment, often a concatenation of agents' views. """ agents_view: chex.Array # (num_agents, num_obs_features) action_mask: chex.Array # (num_agents, num_actions) - global_state: chex.Array # (num_agents, num_agents * num_obs_features, ) + global_state: chex.Array # (num_agents, num_agents * num_obs_features) step_count: chex.Array # (num_agents, ) @@ -73,8 +73,8 @@ class LogEnvState: episode_length_info: chex.Numeric -RnnObservation: TypeAlias = Tuple[Observation, Done] -RnnGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done] +RNNObservation: TypeAlias = Tuple[Observation, Done] +RNNGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done] class PPOTransition(NamedTuple): @@ -176,6 +176,6 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ - [FrozenDict, HiddenState, RnnObservation], Tuple[HiddenState, Distribution] + [FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Distribution] ] -RecCriticApply = Callable[[FrozenDict, HiddenState, RnnObservation], Tuple[HiddenState, Value]] +RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]] From b3ed0e5f83c334e6e7d9b5a62e2c8b119555d861 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Wed, 29 Nov 2023 17:26:02 +0100 Subject: [PATCH 10/10] feat(shaed.py): move the shared wrappers in a seperate file --- mava/systems/ff_ippo_rware.py | 3 +- mava/systems/ff_mappo_rware.py | 8 +- mava/systems/rec_ippo_rware.py | 3 +- mava/systems/rec_mappo_rware.py | 8 +- mava/types.py | 4 +- mava/wrappers/jumanji.py | 171 +---------------------------- mava/wrappers/shared.py | 188 ++++++++++++++++++++++++++++++++ 7 files changed, 200 insertions(+), 185 deletions(-) create mode 100644 mava/wrappers/shared.py diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py index cf93a4bc9..875c93f5b 100644 --- a/mava/systems/ff_ippo_rware.py +++ b/mava/systems/ff_ippo_rware.py @@ -49,7 +49,8 @@ PPOTransition, ) from mava.utils.jax import merge_leading_dims -from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper +from mava.wrappers.jumanji import RwareWrapper +from mava.wrappers.shared import AgentIDWrapper, LogWrapper class Actor(nn.Module): diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py index fd6845271..e06aca90e 100644 --- a/mava/systems/ff_mappo_rware.py +++ b/mava/systems/ff_mappo_rware.py @@ -49,12 +49,8 @@ PPOTransition, ) from mava.utils.jax import merge_leading_dims -from mava.wrappers.jumanji import ( - AgentIDWrapper, - GlobalStateWrapper, - LogWrapper, - RwareWrapper, -) +from mava.wrappers.jumanji import RwareWrapper +from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper class Actor(nn.Module): diff --git a/mava/systems/rec_ippo_rware.py 
b/mava/systems/rec_ippo_rware.py
index 9678b3715..4c1ee3636 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -50,7 +50,8 @@
     RNNLearnerState,
     RNNObservation,
 )
-from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, LogWrapper
 
 
 class ScannedRNN(nn.Module):
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index 9765bc1b8..8aa49d716 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -51,12 +51,8 @@
     RNNGlobalObservation,
     RNNLearnerState,
 )
-from mava.wrappers.jumanji import (
-    AgentIDWrapper,
-    GlobalStateWrapper,
-    LogWrapper,
-    RwareWrapper,
-)
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper
 
 
 class ScannedRNN(nn.Module):
diff --git a/mava/types.py b/mava/types.py
index 1d47e0958..cf5ead1d8 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -37,9 +37,7 @@
 class Observation(NamedTuple):
     """The observation that the agent sees.
 
-    agents_view: the agents' view of other agents and items within their
-        field of view (fov). The number of features in the observation array
-        depends on the number of elements that can be seen in the fov of the agent.
+    agents_view: the agent's view of the environment.
     action_mask: boolean array specifying, for each agent, which action is legal.
     step_count: the number of steps elapsed since the beginning of the episode.
     """
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 48bac789b..e757a6f33 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -12,52 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union +from typing import Tuple import chex import jax.numpy as jnp from jumanji import specs -from jumanji.env import Environment -from jumanji.environments.routing.robot_warehouse import RobotWarehouse, State +from jumanji.environments.routing.robot_warehouse import RobotWarehouse from jumanji.types import TimeStep from jumanji.wrappers import Wrapper -from mava.types import LogEnvState, Observation, ObservationGlobalState - - -class LogWrapper(Wrapper): - """Log the episode returns and lengths.""" - - def reset(self, key: chex.PRNGKey) -> Tuple[LogEnvState, TimeStep]: - """Reset the environment.""" - state, timestep = self._env.reset(key) - state = LogEnvState(state, jnp.float32(0.0), 0, jnp.float32(0.0), 0) - return state, timestep - - def step( - self, - state: LogEnvState, - action: chex.Array, - ) -> Tuple[LogEnvState, TimeStep]: - """Step the environment.""" - env_state, timestep = self._env.step(state.env_state, action) - - done = timestep.last() - not_done = 1 - done - - new_episode_return = state.episode_returns + jnp.mean(timestep.reward) - new_episode_length = state.episode_lengths + 1 - episode_return_info = state.episode_return_info * not_done + new_episode_return * done - episode_length_info = state.episode_length_info * not_done + new_episode_length * done - - state = LogEnvState( - env_state=env_state, - episode_returns=new_episode_return * not_done, - episode_lengths=new_episode_length * not_done, - episode_return_info=episode_return_info, - episode_length_info=episode_length_info, - ) - return state, timestep +from mava.types import Observation, State class RwareWrapper(Wrapper): @@ -98,132 +62,3 @@ def observation_spec(self) -> specs.Spec[Observation]: "step_count", ) return self._env.observation_spec().replace(step_count=step_count) - - -class AgentIDWrapper(Wrapper): - """Add onehot agent IDs to observation.""" - - def __init__(self, env: Environment, has_global_state: bool = False): - super().__init__(env) - self.num_obs_features = self._env.num_obs_features + self._env.num_agents - self.has_global_state = has_global_state - - def _add_agent_ids( - self, timestep: TimeStep, num_agents: int - ) -> Union[Observation, ObservationGlobalState]: - agent_ids = jnp.eye(num_agents) - new_agents_view = jnp.concatenate([agent_ids, timestep.observation.agents_view], axis=-1) - - if self.has_global_state: - # Add the agent IDs to the global state - new_global_state = jnp.concatenate( - [agent_ids, timestep.observation.global_state], axis=-1 - ) - - return ObservationGlobalState( - agents_view=new_agents_view, - action_mask=timestep.observation.action_mask, - step_count=timestep.observation.step_count, - global_state=new_global_state, - ) - - else: - return Observation( - agents_view=new_agents_view, - action_mask=timestep.observation.action_mask, - step_count=timestep.observation.step_count, - ) - - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - """Reset the environment.""" - state, timestep = self._env.reset(key) - timestep.observation = self._add_agent_ids(timestep, self._env.num_agents) - - return state, timestep - - def step( - self, - state: State, - action: chex.Array, - ) -> Tuple[State, TimeStep]: - """Step the environment.""" - state, timestep = self._env.step(state, action) - timestep.observation = self._add_agent_ids(timestep, self._env.num_agents) - - return state, timestep - - def observation_spec( - self, - ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]: - """Specification of the observation of 
the `RobotWarehouse` environment.""" - agents_view = specs.Array( - (self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view" - ) - global_state = specs.Array( - ( - self._env.num_agents, - self._env.num_obs_features * self._env.num_agents + self._env.num_agents, - ), - jnp.int32, - "global_state", - ) - - if self.has_global_state: - return self._env.observation_spec().replace( - agents_view=agents_view, - global_state=global_state, - ) - return self._env.observation_spec().replace( - agents_view=agents_view, - ) - - -class GlobalStateWrapper(Wrapper): - """Wrapper for adding global state to an environment that follows the mava API. - - The wrapper includes a global environment state to be used by the centralised critic. - Note here that since most environments do not have a global state, we create one - by concatenating the observations of all agents. - """ - - def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]: - global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) - global_state = jnp.tile(global_state, (self._env.num_agents, 1)) - - observation = ObservationGlobalState( - global_state=global_state, - agents_view=timestep.observation.agents_view, - action_mask=timestep.observation.action_mask, - step_count=timestep.observation.step_count, - ) - - return timestep.replace(observation=observation) - - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - """Reset the environment. Updates the step count.""" - state, timestep = self._env.reset(key) - return state, self.modify_timestep(timestep) - - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: - """Step the environment. Updates the step count.""" - state, timestep = self._env.step(state, action) - return state, self.modify_timestep(timestep) - - def observation_spec(self) -> specs.Spec[ObservationGlobalState]: - """Specification of the observation of the `RobotWarehouse` environment.""" - - obs_spec = self._env.observation_spec() - global_state = specs.Array( - (self._env.num_agents, self._env.num_agents * self._env.num_obs_features), - jnp.int32, - "global_state", - ) - - return specs.Spec( - ObservationGlobalState, - "ObservationSpec", - agents_view=obs_spec.agents_view, - action_mask=obs_spec.action_mask, - global_state=global_state, - step_count=obs_spec.step_count, - ) diff --git a/mava/wrappers/shared.py b/mava/wrappers/shared.py new file mode 100644 index 000000000..459a169e3 --- /dev/null +++ b/mava/wrappers/shared.py @@ -0,0 +1,188 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple, Union + +import chex +import jax.numpy as jnp +from jumanji import specs +from jumanji.env import Environment +from jumanji.types import TimeStep +from jumanji.wrappers import Wrapper + +from mava.types import LogEnvState, Observation, ObservationGlobalState, State + + +class LogWrapper(Wrapper): + """Log the episode returns and lengths.""" + + def reset(self, key: chex.PRNGKey) -> Tuple[LogEnvState, TimeStep]: + """Reset the environment.""" + state, timestep = self._env.reset(key) + state = LogEnvState(state, jnp.float32(0.0), 0, jnp.float32(0.0), 0) + return state, timestep + + def step( + self, + state: LogEnvState, + action: chex.Array, + ) -> Tuple[LogEnvState, TimeStep]: + """Step the environment.""" + env_state, timestep = self._env.step(state.env_state, action) + + done = timestep.last() + not_done = 1 - done + + new_episode_return = state.episode_returns + jnp.mean(timestep.reward) + new_episode_length = state.episode_lengths + 1 + episode_return_info = state.episode_return_info * not_done + new_episode_return * done + episode_length_info = state.episode_length_info * not_done + new_episode_length * done + + state = LogEnvState( + env_state=env_state, + episode_returns=new_episode_return * not_done, + episode_lengths=new_episode_length * not_done, + episode_return_info=episode_return_info, + episode_length_info=episode_length_info, + ) + return state, timestep + + +class AgentIDWrapper(Wrapper): + """Add onehot agent IDs to observation.""" + + def __init__(self, env: Environment, has_global_state: bool = False): + super().__init__(env) + self.num_obs_features = self._env.num_obs_features + self._env.num_agents + self.has_global_state = has_global_state + + def _add_agent_ids( + self, timestep: TimeStep, num_agents: int + ) -> Union[Observation, ObservationGlobalState]: + agent_ids = jnp.eye(num_agents) + new_agents_view = jnp.concatenate([agent_ids, timestep.observation.agents_view], axis=-1) + + if self.has_global_state: + # Add the agent IDs to the global state + new_global_state = jnp.concatenate( + [agent_ids, timestep.observation.global_state], axis=-1 + ) + + return ObservationGlobalState( + agents_view=new_agents_view, + action_mask=timestep.observation.action_mask, + step_count=timestep.observation.step_count, + global_state=new_global_state, + ) + + else: + return Observation( + agents_view=new_agents_view, + action_mask=timestep.observation.action_mask, + step_count=timestep.observation.step_count, + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + """Reset the environment.""" + state, timestep = self._env.reset(key) + timestep.observation = self._add_agent_ids(timestep, self._env.num_agents) + + return state, timestep + + def step( + self, + state: State, + action: chex.Array, + ) -> Tuple[State, TimeStep]: + """Step the environment.""" + state, timestep = self._env.step(state, action) + timestep.observation = self._add_agent_ids(timestep, self._env.num_agents) + + return state, timestep + + def observation_spec( + self, + ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]: + """Specification of the observation of the `RobotWarehouse` environment.""" + agents_view = specs.Array( + (self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view" + ) + global_state = specs.Array( + ( + self._env.num_agents, + self._env.num_obs_features * self._env.num_agents + self._env.num_agents, + ), + jnp.int32, + "global_state", + ) + + if self.has_global_state: + return 
self._env.observation_spec().replace( + agents_view=agents_view, + global_state=global_state, + ) + return self._env.observation_spec().replace( + agents_view=agents_view, + ) + + +class GlobalStateWrapper(Wrapper): + """Wrapper for adding global state to an environment that follows the mava API. + + The wrapper includes a global environment state to be used by the centralised critic. + Note here that since most environments do not have a global state, we create one + by concatenating the observations of all agents. + """ + + def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]: + global_state = jnp.concatenate(timestep.observation.agents_view, axis=0) + global_state = jnp.tile(global_state, (self._env.num_agents, 1)) + + observation = ObservationGlobalState( + global_state=global_state, + agents_view=timestep.observation.agents_view, + action_mask=timestep.observation.action_mask, + step_count=timestep.observation.step_count, + ) + + return timestep.replace(observation=observation) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + """Reset the environment. Updates the step count.""" + state, timestep = self._env.reset(key) + return state, self.modify_timestep(timestep) + + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + """Step the environment. Updates the step count.""" + state, timestep = self._env.step(state, action) + return state, self.modify_timestep(timestep) + + def observation_spec(self) -> specs.Spec[ObservationGlobalState]: + """Specification of the observation of the `RobotWarehouse` environment.""" + + obs_spec = self._env.observation_spec() + global_state = specs.Array( + (self._env.num_agents, self._env.num_agents * self._env.num_obs_features), + jnp.int32, + "global_state", + ) + + return specs.Spec( + ObservationGlobalState, + "ObservationSpec", + agents_view=obs_spec.agents_view, + action_mask=obs_spec.action_mask, + global_state=global_state, + step_count=obs_spec.step_count, + )
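
As a standalone sanity check of the wrappers that patch 10 moves into mava/wrappers/shared.py, the sketch below reproduces the array manipulations performed by GlobalStateWrapper.modify_timestep and AgentIDWrapper._add_agent_ids on a toy observation. It is not part of the patches above, and the agent and feature counts are made-up values chosen only so the shapes are easy to follow.

import jax.numpy as jnp

num_agents, num_obs_features = 4, 5
agents_view = jnp.ones((num_agents, num_obs_features))  # toy per-agent observations

# GlobalStateWrapper: flatten all agents' views into one vector and repeat it per agent.
global_state = jnp.concatenate(agents_view, axis=0)     # (num_agents * num_obs_features,)
global_state = jnp.tile(global_state, (num_agents, 1))  # (num_agents, num_agents * num_obs_features)

# AgentIDWrapper: prepend a one-hot agent ID to each agent's view
# (and to the global state when has_global_state is True).
agent_ids = jnp.eye(num_agents)
new_agents_view = jnp.concatenate([agent_ids, agents_view], axis=-1)    # (4, 4 + 5)
new_global_state = jnp.concatenate([agent_ids, global_state], axis=-1)  # (4, 4 + 20)

print(global_state.shape, new_agents_view.shape, new_global_state.shape)

Tiling the concatenated views gives every agent an identical global_state row, which matches the (num_agents, num_agents * num_obs_features) shape declared in GlobalStateWrapper.observation_spec.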