From ff7d345dc2383df1129051d312fae607a6703805 Mon Sep 17 00:00:00 2001
From: Sasha Abramowitz
Date: Mon, 20 Nov 2023 18:13:19 +0200
Subject: [PATCH 01/10] feat: general wrapper mostly complete
Global state wrapper not yet working properly
---
mava/systems/ff_ippo_rware.py | 8 ++--
mava/systems/ff_mappo_rware.py | 9 ++--
mava/systems/rec_ippo_rware.py | 11 +++--
mava/systems/rec_mappo_rware.py | 12 +++--
mava/types.py | 17 ++++---
mava/wrappers/jumanji.py | 84 +++++++++++++++------------------
6 files changed, 70 insertions(+), 71 deletions(-)
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index bf8648f94..c4df33c87 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -30,7 +30,6 @@
from flax.linen.initializers import constant, orthogonal
from jumanji.env import Environment
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
-from jumanji.types import Observation
from jumanji.wrappers import AutoResetWrapper
from omegaconf import DictConfig, OmegaConf
from optax._src.base import OptState
@@ -44,6 +43,7 @@
ExperimentOutput,
LearnerFn,
LearnerState,
+ Observation,
OptStates,
Params,
PPOTransition,
@@ -51,7 +51,7 @@
from mava.utils.jax import merge_leading_dims
from mava.utils.logger_tools import get_sacred_exp
from mava.utils.timing_utils import TimeIt
-from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper
+from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
class Actor(nn.Module):
@@ -497,14 +497,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
# Create envs
generator = RandomGenerator(**config["rware_scenario"]["task_config"])
env = jumanji.make(config["env_name"], generator=generator)
- env = RwareMultiAgentWrapper(env)
+ env = RwareWrapper(env)
# Add agent id to observation.
if config["add_agent_id"]:
env = AgentIDWrapper(env)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = RwareMultiAgentWrapper(eval_env)
+ eval_env = RwareWrapper(eval_env)
if config["add_agent_id"]:
eval_env = AgentIDWrapper(eval_env)
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index cafb98dff..10d6e5159 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -29,7 +29,6 @@
from flax.linen.initializers import constant, orthogonal
from jumanji.env import Environment
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
-from jumanji.types import Observation
from jumanji.wrappers import AutoResetWrapper
from omegaconf import DictConfig, OmegaConf
from optax._src.base import OptState
@@ -43,6 +42,7 @@
ExperimentOutput,
LearnerFn,
LearnerState,
+ Observation,
OptStates,
Params,
PPOTransition,
@@ -52,9 +52,10 @@
from mava.utils.timing_utils import TimeIt
from mava.wrappers.jumanji import (
AgentIDWrapper,
+ GlobalStateWrapper,
LogWrapper,
ObservationGlobalState,
- RwareMultiAgentWithGlobalStateWrapper,
+ RwareWrapper,
)
@@ -506,14 +507,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
# Create envs
generator = RandomGenerator(**config["rware_scenario"]["task_config"])
env = jumanji.make(config["env_name"], generator=generator)
- env = RwareMultiAgentWithGlobalStateWrapper(env)
+ env = GlobalStateWrapper(RwareWrapper(env))
# Add agent id to observation.
if config["add_agent_id"]:
env = AgentIDWrapper(env=env, has_global_state=True)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env)
+ eval_env = GlobalStateWrapper(RwareWrapper(env))
if config["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index be1976966..9f5f50df8 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -48,10 +48,11 @@
RecActorApply,
RecCriticApply,
RNNLearnerState,
+ RnnObservation,
)
from mava.utils.logger_tools import get_sacred_exp
from mava.utils.timing_utils import TimeIt
-from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper
+from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
class ScannedRNN(nn.Module):
@@ -91,7 +92,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
- observation_done: Tuple[chex.Array, chex.Array],
+ observation_done: RnnObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -130,7 +131,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
- observation_done: Tuple[chex.Array, chex.Array],
+ observation_done: RnnObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
@@ -677,14 +678,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
# Create envs
generator = RandomGenerator(**config["rware_scenario"]["task_config"])
env = jumanji.make(config["env_name"], generator=generator)
- env = RwareMultiAgentWrapper(env)
+ env = RwareWrapper(env)
# Add agent id to observation.
if config["add_agent_id"]:
env = AgentIDWrapper(env)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = RwareMultiAgentWrapper(eval_env)
+ eval_env = RwareWrapper(eval_env)
if config["add_agent_id"]:
eval_env = AgentIDWrapper(eval_env)
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index a53390040..e92146af6 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -48,14 +48,16 @@
RecActorApply,
RecCriticApply,
RNNLearnerState,
+ RnnObservation,
)
from mava.utils.logger_tools import get_sacred_exp
from mava.utils.timing_utils import TimeIt
from mava.wrappers.jumanji import (
AgentIDWrapper,
+ GlobalStateWrapper,
LogWrapper,
ObservationGlobalState,
- RwareMultiAgentWithGlobalStateWrapper,
+ RwareWrapper,
)
@@ -96,7 +98,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
- observation_done: Tuple[chex.Array, chex.Array],
+ observation_done: RnnObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -135,7 +137,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
- observation_done: Tuple[chex.Array, chex.Array],
+ observation_done: RnnObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
@@ -686,14 +688,14 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
# Create envs
generator = RandomGenerator(**config["rware_scenario"]["task_config"])
env = jumanji.make(config["env_name"], generator=generator)
- env = RwareMultiAgentWithGlobalStateWrapper(env)
+ env = GlobalStateWrapper(RwareWrapper(env))
# Add agent id to observation.
if config["add_agent_id"]:
env = AgentIDWrapper(env=env, has_global_state=True)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env)
+ eval_env = GlobalStateWrapper(RwareWrapper(env))
if config["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
diff --git a/mava/types.py b/mava/types.py
index d7db2f30e..9da2adf29 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -21,16 +21,21 @@
from optax._src.base import OptState
from typing_extensions import NamedTuple, TypeAlias
-from mava.wrappers.jumanji import LogEnvState
-
Action: TypeAlias = chex.Array
Value: TypeAlias = chex.Array
Done: TypeAlias = chex.Array
HiddenState: TypeAlias = chex.Array
-# Can't know the exact type of State or Timestep.
+# Can't know the exact type of State.
State: TypeAlias = Any
-Observation: TypeAlias = Any
+
+
+class Observation(NamedTuple):
+ agents_view: chex.Array
+ action_mask: chex.Array
+ step_count: chex.Numeric
+
+
RnnObservation: TypeAlias = Tuple[Observation, Done]
@@ -73,7 +78,7 @@ class LearnerState(NamedTuple):
params: Params
opt_states: OptStates
key: chex.PRNGKey
- env_state: LogEnvState
+ env_state: State
timestep: TimeStep
@@ -83,7 +88,7 @@ class RNNLearnerState(NamedTuple):
params: Params
opt_states: OptStates
key: chex.PRNGKey
- env_state: LogEnvState
+ env_state: State
timestep: TimeStep
dones: Done
hstates: HiddenStates
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 847c0cde5..206096d79 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -18,10 +18,12 @@
import jax.numpy as jnp
from jumanji import specs
from jumanji.env import Environment
-from jumanji.environments.routing.robot_warehouse import Observation, State
+from jumanji.environments.routing.robot_warehouse import RobotWarehouse, State
from jumanji.types import TimeStep
from jumanji.wrappers import Wrapper
+from mava.types import Observation
+
if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
@@ -169,28 +171,31 @@ def observation_spec(self) -> specs.Spec[Observation]:
)
-class RwareMultiAgentWrapper(Wrapper):
+class RwareWrapper(Wrapper):
"""Multi-agent wrapper for the Robotic Warehouse environment."""
- def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
- """Reset the environment. Updates the step count."""
- state, timestep = self._env.reset(key)
- timestep.observation = Observation(
+ def __init__(self, env: RobotWarehouse):
+ super().__init__(env)
+ self._env: RobotWarehouse
+
+ def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ observation = Observation(
agents_view=timestep.observation.agents_view,
action_mask=timestep.observation.action_mask,
step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents),
)
- return state, timestep
+
+ return timestep.replace(observation=observation)
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+ """Reset the environment. Updates the step count."""
+ state, timestep = self._env.reset(key)
+ return state, self.convert_timstep(timestep)
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""Step the environment. Updates the step count."""
state, timestep = self._env.step(state, action)
- timestep.observation = Observation(
- agents_view=timestep.observation.agents_view,
- action_mask=timestep.observation.action_mask,
- step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents),
- )
- return state, timestep
+ return state, self.convert_timstep(timestep)
def observation_spec(self) -> specs.Spec[Observation]:
"""Specification of the observation of the `RobotWarehouse` environment."""
@@ -204,67 +209,52 @@ def observation_spec(self) -> specs.Spec[Observation]:
return self._env.observation_spec().replace(step_count=step_count)
-class RwareMultiAgentWithGlobalStateWrapper(Wrapper):
- """Multi-agent wrapper for the Robotic Warehouse environment.
+class GlobalStateWrapper(Wrapper):
+ """Wrapper for adding global state to an environment that follows the mava API.
The wrapper includes a global environment state to be used by the centralised critic.
Note here that since robotic warehouse does not have a global state, we create one
by concatenating the observations of all agents.
"""
- def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
- """Reset the environment. Updates the step count."""
- state, timestep = self._env.reset(key)
+ def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]:
global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
global_state = jnp.tile(global_state, (self._env.num_agents, 1))
- timestep.observation = ObservationGlobalState(
+
+ observation = ObservationGlobalState(
+ global_state=global_state,
agents_view=timestep.observation.agents_view,
action_mask=timestep.observation.action_mask,
- global_state=global_state,
- step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents),
+ step_count=timestep.observation.step_count,
)
- return state, timestep
+
+ return timestep.replace(observation=observation)
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+ """Reset the environment. Updates the step count."""
+ state, timestep = self._env.reset(key)
+ return state, self.convert_timstep(timestep)
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""Step the environment. Updates the step count."""
state, timestep = self._env.step(state, action)
- global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
- global_state = jnp.tile(global_state, (self._env.num_agents, 1))
- timestep.observation = ObservationGlobalState(
- agents_view=timestep.observation.agents_view,
- action_mask=timestep.observation.action_mask,
- global_state=global_state,
- step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents),
- )
- return state, timestep
+ return state, self.convert_timstep(timestep)
def observation_spec(self) -> specs.Spec[ObservationGlobalState]:
"""Specification of the observation of the `RobotWarehouse` environment."""
- agents_view = specs.Array(
- (self._env.num_agents, self._env.num_obs_features), jnp.int32, "agents_view"
- )
- action_mask = specs.BoundedArray(
- (self._env.num_agents, 5), bool, False, True, "action_mask"
- )
+ obs_spec = self._env.observation_spec()
global_state = specs.Array(
(self._env.num_agents, self._env.num_agents * self._env.num_obs_features),
jnp.int32,
"global_state",
)
- step_count = specs.BoundedArray(
- (self._env.num_agents,),
- jnp.int32,
- [0] * self._env.num_agents,
- [self._env.time_limit] * self._env.num_agents,
- "step_count",
- )
return specs.Spec(
ObservationGlobalState,
"ObservationSpec",
- agents_view=agents_view,
- action_mask=action_mask,
+ agents_view=obs_spec.agents_view,
+ action_mask=obs_spec.action_mask,
global_state=global_state,
- step_count=step_count,
+ step_count=obs_spec.step_count,
)
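The net effect of this patch is to split the old monolithic RwareMultiAgentWithGlobalStateWrapper into two composable layers: RwareWrapper adapts the Jumanji environment to Mava's Observation type, and GlobalStateWrapper stacks the concatenated global state on top. A minimal standalone sketch of the resulting wrapper stack, assuming the Jumanji env id "RobotWarehouse-v0" and illustrative RandomGenerator arguments (neither is taken from the patch):

import jax
import jumanji
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
from jumanji.wrappers import AutoResetWrapper

from mava.wrappers.jumanji import (
    AgentIDWrapper,
    GlobalStateWrapper,
    LogWrapper,
    RwareWrapper,
)

# Assumed generator arguments for a tiny two-agent layout.
generator = RandomGenerator(
    shelf_rows=1,
    shelf_columns=3,
    column_height=8,
    num_agents=2,
    sensor_range=1,
    request_queue_size=2,
)
env = jumanji.make("RobotWarehouse-v0", generator=generator)
env = GlobalStateWrapper(RwareWrapper(env))  # per-agent obs, then shared global state
env = AgentIDWrapper(env=env, has_global_state=True)  # append one-hot agent ids
env = AutoResetWrapper(env)  # auto-reset on episode end
env = LogWrapper(env)  # episode return/length bookkeeping

state, timestep = env.reset(jax.random.PRNGKey(0))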
From b1f33fa1ab613c9d903161724c1bdf333e2487f7 Mon Sep 17 00:00:00 2001
From: Sasha Abramowitz
Date: Tue, 21 Nov 2023 10:45:05 +0200
Subject: [PATCH 02/10] fix: centralized critic eval bug with new wrapper
---
mava/systems/ff_mappo_rware.py | 2 +-
mava/systems/rec_mappo_rware.py | 2 +-
mava/wrappers/jumanji.py | 21 +++++++++++++--------
3 files changed, 15 insertions(+), 10 deletions(-)
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index 10d6e5159..e6dc4a4de 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -514,7 +514,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = GlobalStateWrapper(RwareWrapper(env))
+ eval_env = GlobalStateWrapper(RwareWrapper(eval_env))
if config["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index e92146af6..b69223375 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -695,7 +695,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env_name"], generator=generator)
- eval_env = GlobalStateWrapper(RwareWrapper(env))
+ eval_env = GlobalStateWrapper(RwareWrapper(eval_env))
if config["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 206096d79..ee14153a1 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -178,24 +178,29 @@ def __init__(self, env: RobotWarehouse):
super().__init__(env)
self._env: RobotWarehouse
- def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ n_agents = self._env.num_agents
observation = Observation(
agents_view=timestep.observation.agents_view,
action_mask=timestep.observation.action_mask,
- step_count=jnp.repeat(timestep.observation.step_count, self._env.num_agents),
+ step_count=jnp.repeat(timestep.observation.step_count, n_agents),
)
+ # todo (weim): get this working so that ppo always expects (n_agents,) size for reward and discount
+ # reward = jnp.repeat(timestep.reward, n_agents)
+ # discount = jnp.repeat(timestep.discount, n_agents)
+ # return timestep.replace(observation=observation, reward=reward, discount=discount)
return timestep.replace(observation=observation)
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment. Updates the step count."""
state, timestep = self._env.reset(key)
- return state, self.convert_timstep(timestep)
+ return state, self.modify_timestep(timestep)
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""Step the environment. Updates the step count."""
state, timestep = self._env.step(state, action)
- return state, self.convert_timstep(timestep)
+ return state, self.modify_timestep(timestep)
def observation_spec(self) -> specs.Spec[Observation]:
"""Specification of the observation of the `RobotWarehouse` environment."""
@@ -213,11 +218,11 @@ class GlobalStateWrapper(Wrapper):
"""Wrapper for adding global state to an environment that follows the mava API.
The wrapper includes a global environment state to be used by the centralised critic.
- Note here that since robotic warehouse does not have a global state, we create one
+ Note here that since most environments do not have a global state, we create one
by concatenating the observations of all agents.
"""
- def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
global_state = jnp.tile(global_state, (self._env.num_agents, 1))
@@ -233,12 +238,12 @@ def convert_timstep(self, timestep: TimeStep) -> TimeStep[Observation]:
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment. Updates the step count."""
state, timestep = self._env.reset(key)
- return state, self.convert_timstep(timestep)
+ return state, self.modify_timestep(timestep)
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""Step the environment. Updates the step count."""
state, timestep = self._env.step(state, action)
- return state, self.convert_timstep(timestep)
+ return state, self.modify_timestep(timestep)
def observation_spec(self) -> specs.Spec[ObservationGlobalState]:
"""Specification of the observation of the `RobotWarehouse` environment."""
From b520cf391ba798dc79df3306170c9d1624c4ad7f Mon Sep 17 00:00:00 2001
From: Sasha Abramowitz
Date: Tue, 21 Nov 2023 20:34:11 +0200
Subject: [PATCH 03/10] feat: repeat reward and discount in wrapper
---
mava/evaluator.py | 4 +++-
mava/systems/ff_ippo_rware.py | 7 ++-----
mava/wrappers/jumanji.py | 9 +++------
3 files changed, 8 insertions(+), 12 deletions(-)
diff --git a/mava/evaluator.py b/mava/evaluator.py
index ce6c7c221..5c740304b 100644
--- a/mava/evaluator.py
+++ b/mava/evaluator.py
@@ -105,7 +105,9 @@ def evaluator_fn(trained_params: FrozenDict, rng: chex.PRNGKey) -> ExperimentOut
# Add dimension to pmap over.
step_rngs = jnp.stack(step_rngs).reshape(eval_batch, -1)
- eval_state = EvalState(step_rngs, env_states, timesteps, 0, 0.0)
+ eval_state = EvalState(
+ step_rngs, env_states, timesteps, 0, jnp.zeros_like(timesteps.reward)
+ )
eval_metrics = jax.vmap(
eval_one_episode,
in_axes=(None, EvalState(0, 0, 0, None, None)),
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index c4df33c87..09686a616 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -151,17 +151,14 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
- done, reward = jax.tree_util.tree_map(
- lambda x: jnp.repeat(x, config["num_agents"]).reshape(config["num_envs"], -1),
- (timestep.last(), timestep.reward),
- )
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}
+ done = 1 - timestep.discount
transition = PPOTransition(
- done, action, value, reward, log_prob, last_timestep.observation, info
+ done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
learner_state = LearnerState(params, opt_states, rng, env_state, timestep)
return learner_state, transition
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index ee14153a1..18ab1482a 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -185,12 +185,9 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
action_mask=timestep.observation.action_mask,
step_count=jnp.repeat(timestep.observation.step_count, n_agents),
)
- # todo (weim): get this working so that ppo always expects (n_agents,) size for reward and discount
- # reward = jnp.repeat(timestep.reward, n_agents)
- # discount = jnp.repeat(timestep.discount, n_agents)
- # return timestep.replace(observation=observation, reward=reward, discount=discount)
-
- return timestep.replace(observation=observation)
+ reward = jnp.repeat(timestep.reward, n_agents)
+ discount = jnp.repeat(timestep.discount, n_agents)
+ return timestep.replace(observation=observation, reward=reward, discount=discount)
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment. Updates the step count."""
From 4236d33448fb9d640080d230c62a50e3a1a0d88f Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Wed, 22 Nov 2023 19:32:44 +0100
Subject: [PATCH 04/10] feat: edit types and wrapper for other systems
---
mava/configs/arch/anakin.yaml | 2 +-
mava/evaluator.py | 2 +-
mava/systems/ff_ippo_rware.py | 8 +++++-
mava/systems/ff_mappo_rware.py | 8 +++---
mava/systems/rec_ippo_rware.py | 6 ++---
mava/systems/rec_mappo_rware.py | 6 ++---
mava/types.py | 41 +++++++++++++++++++++++++++---
mava/wrappers/jumanji.py | 45 +++++----------------------------
8 files changed, 62 insertions(+), 56 deletions(-)
diff --git a/mava/configs/arch/anakin.yaml b/mava/configs/arch/anakin.yaml
index 86e75898b..ef9aa4314 100644
--- a/mava/configs/arch/anakin.yaml
+++ b/mava/configs/arch/anakin.yaml
@@ -1,7 +1,7 @@
# --- Anakin config ---
# --- Training ---
-num_envs: 16 # Number of vectorised environments per device.
+num_envs: 256 # Number of vectorised environments per device.
# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
diff --git a/mava/evaluator.py b/mava/evaluator.py
index 7c2fa04c7..ac77c7723 100644
--- a/mava/evaluator.py
+++ b/mava/evaluator.py
@@ -233,7 +233,7 @@ def evaluator_fn(
dones=dones,
hstate=init_hstate,
step_count_=0,
- return_=0.0,
+ return_=jnp.zeros_like(timesteps.reward),
)
eval_metrics = jax.vmap(
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index 0e16c3916..14e0b7529 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -157,7 +157,13 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
"episode_length": env_state.episode_length_info,
}
- done = 1 - timestep.discount
+ # done = 1 - timestep.discount
+ done = jax.tree_util.tree_map(
+ lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
+ config["arch"]["num_envs"], -1
+ ),
+ timestep.last(),
+ )
transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index f5021bd5a..00b824d2b 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -157,11 +157,11 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
- done, reward = jax.tree_util.tree_map(
+ done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
- (timestep.last(), timestep.reward),
+ timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
@@ -169,7 +169,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
}
transition = PPOTransition(
- done, action, value, reward, log_prob, last_timestep.observation, info
+ done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
learner_state = LearnerState(params, opt_states, rng, env_state, timestep)
return learner_state, transition
@@ -526,7 +526,7 @@ def run_experiment(_run: run.Run, _config: Dict, _log: SacredLogger) -> None:
env = LogWrapper(env)
eval_env = jumanji.make(config["env"]["env_name"], generator=generator)
eval_env = GlobalStateWrapper(RwareWrapper(eval_env))
- if config["add_agent_id"]:
+ if config["system"]["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
# PRNG keys.
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index 4a3ac847c..f151d2cba 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -229,11 +229,11 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done, reward = jax.tree_util.tree_map(
+ done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
- (timestep.last(), timestep.reward),
+ timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
@@ -241,7 +241,7 @@ def _env_step(
}
transition = PPOTransition(
- done, action, value, reward, log_prob, last_timestep.observation, info
+ done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
hstates = HiddenStates(policy_hidden_state, critic_hidden_state)
learner_state = RNNLearnerState(
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index 05699c000..afd4d4ac3 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -229,11 +229,11 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done, reward = jax.tree_util.tree_map(
+ done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
- (timestep.last(), timestep.reward),
+ timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
@@ -241,7 +241,7 @@ def _env_step(
}
transition = PPOTransition(
- done, action, value, reward, log_prob, last_timestep.observation, info
+ done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
hstates = HiddenStates(policy_hidden_state, critic_hidden_state)
learner_state = RNNLearnerState(
diff --git a/mava/types.py b/mava/types.py
index b888a3f16..7455e9b2f 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
import chex
from distrax import Distribution
@@ -21,13 +21,16 @@
from optax._src.base import OptState
from typing_extensions import NamedTuple, TypeAlias
-from mava.wrappers.jumanji import LogEnvState
+if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
+ from dataclasses import dataclass
+else:
+ from flax.struct import dataclass
+
Action: TypeAlias = chex.Array
Value: TypeAlias = chex.Array
Done: TypeAlias = chex.Array
HiddenState: TypeAlias = chex.Array
-
# Can't know the exact type of State.
State: TypeAlias = Any
@@ -38,7 +41,37 @@ class Observation(NamedTuple):
step_count: chex.Numeric
-RnnObservation: TypeAlias = Tuple[Observation, Done]
+class ObservationGlobalState(NamedTuple):
+ """The observation that the agent sees.
+ agents_view: the agents' view of other agents and shelves within their
+ sensor range. The number of features in the observation array
+ depends on the sensor range of the agent.
+ action_mask: boolean array specifying, for each agent, which action
+ (up, right, down, left) is legal.
+ global_state: the global state of the environment, which is the
+ concatenation of the agents' views.
+ step_count: the number of steps elapsed since the beginning of the episode.
+ """
+
+ agents_view: chex.Array # (num_agents, num_obs_features)
+ action_mask: chex.Array # (num_agents, num_actions)
+ global_state: chex.Array # (num_agents, num_agents * num_obs_features, )
+ step_count: chex.Array # (num_agents, )
+
+
+@dataclass
+class LogEnvState:
+ """State of the `LogWrapper`."""
+
+ env_state: State
+ episode_returns: chex.Numeric
+ episode_lengths: chex.Numeric
+ # Information about the episode return and length for logging purposes.
+ episode_return_info: chex.Numeric
+ episode_length_info: chex.Numeric
+
+
+RnnObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
class PPOTransition(NamedTuple):
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 0b9e87475..aec4c0ed0 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, NamedTuple, Tuple, Union
+from typing import Tuple, Union
import chex
import jax.numpy as jnp
@@ -22,42 +22,7 @@
from jumanji.types import TimeStep
from jumanji.wrappers import Wrapper
-from mava.types import Observation
-
-if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
- from dataclasses import dataclass
-else:
- from flax.struct import dataclass
-
-
-class ObservationGlobalState(NamedTuple):
- """The observation that the agent sees.
- agents_view: the agents' view of other agents and shelves within their
- sensor range. The number of features in the observation array
- depends on the sensor range of the agent.
- action_mask: boolean array specifying, for each agent, which action
- (up, right, down, left) is legal.
- global_state: the global state of the environment, which is the
- concatenation of the agents' views.
- step_count: the number of steps elapsed since the beginning of the episode.
- """
-
- agents_view: chex.Array # (num_agents, num_obs_features)
- action_mask: chex.Array # (num_agents, num_actions)
- global_state: chex.Array # (num_agents, num_agents * num_obs_features, )
- step_count: chex.Array # (num_agents, )
-
-
-@dataclass
-class LogEnvState:
- """State of the `LogWrapper`."""
-
- env_state: State
- episode_returns: chex.Numeric
- episode_lengths: chex.Numeric
- # Information about the episode return and length for logging purposes.
- episode_return_info: chex.Numeric
- episode_length_info: chex.Numeric
+from mava.types import LogEnvState, Observation, ObservationGlobalState
class LogWrapper(Wrapper):
@@ -186,8 +151,10 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
step_count=jnp.repeat(timestep.observation.step_count, n_agents),
)
reward = jnp.repeat(timestep.reward, n_agents)
- discount = jnp.repeat(timestep.discount, n_agents)
- return timestep.replace(observation=observation, reward=reward, discount=discount)
+ # discount = jnp.repeat(timestep.discount, n_agents)
+ # -> we won't need this if we'll use the timestep.last() for the 'done'
+ # variable during training.
+ return timestep.replace(observation=observation, reward=reward)
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment. Updates the step count."""
From 3117226e1ddebefd82a55acc8100ed83d41f9376 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Thu, 23 Nov 2023 15:43:54 +0100
Subject: [PATCH 05/10] feat: edit rec systems
---
docs/jumanji_rware_comparison.md | 2 +-
mava/systems/ff_ippo_rware.py | 8 +-------
mava/systems/ff_mappo_rware.py | 9 ++-------
mava/systems/rec_ippo_rware.py | 11 +++--------
mava/systems/rec_mappo_rware.py | 19 +++++++------------
mava/types.py | 3 ++-
mava/wrappers/jumanji.py | 9 ++++-----
7 files changed, 20 insertions(+), 41 deletions(-)
diff --git a/docs/jumanji_rware_comparison.md b/docs/jumanji_rware_comparison.md
index 0613c9ad3..3d041ef12 100644
--- a/docs/jumanji_rware_comparison.md
+++ b/docs/jumanji_rware_comparison.md
@@ -66,7 +66,7 @@ Please see below for Mava's recurrent and feedforward implementations of IPPO an
- Mava recurrent IPPO performance on the tiny-2ag, tiny-4ag and small-4ag RWARE tasks.
+ Mava recurrent MAPPO performance on the tiny-2ag, tiny-4ag and small-4ag RWARE tasks.
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index 14e0b7529..6902897dc 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -152,18 +152,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
+ done = 1 - timestep.discount
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}
- # done = 1 - timestep.discount
- done = jax.tree_util.tree_map(
- lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
- config["arch"]["num_envs"], -1
- ),
- timestep.last(),
- )
transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index 00b824d2b..131ee12af 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -45,6 +45,7 @@
LearnerFn,
LearnerState,
Observation,
+ ObservationGlobalState,
OptStates,
Params,
PPOTransition,
@@ -55,7 +56,6 @@
AgentIDWrapper,
GlobalStateWrapper,
LogWrapper,
- ObservationGlobalState,
RwareWrapper,
)
@@ -157,12 +157,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
- done = jax.tree_util.tree_map(
- lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
- config["arch"]["num_envs"], -1
- ),
- timestep.last(),
- )
+ done = 1 - timestep.discount
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index f151d2cba..9dc280e1f 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -229,12 +229,7 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done = jax.tree_util.tree_map(
- lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
- config["arch"]["num_envs"], -1
- ),
- timestep.last(),
- )
+ done = 1 - timestep.discount
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
@@ -586,7 +581,7 @@ def learner_setup(
init_obs,
)
init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs)
- init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool)
+ init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float)
init_x = (init_obs, init_done)
# Initialise hidden states.
@@ -663,7 +658,7 @@ def learner_setup(
config["arch"]["num_envs"],
config["system"]["num_agents"],
),
- dtype=bool,
+ dtype=float,
)
hstates = HiddenStates(policy_hstates, critic_hstates)
params = Params(actor_params, critic_params)
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index afd4d4ac3..fae0c0190 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -44,20 +44,20 @@
ExperimentOutput,
HiddenStates,
LearnerFn,
+ ObservationGlobalState,
OptStates,
Params,
PPOTransition,
RecActorApply,
RecCriticApply,
+ RnnGlobalObservation,
RNNLearnerState,
- RnnObservation,
)
from mava.utils.logger_tools import get_sacred_exp
from mava.wrappers.jumanji import (
AgentIDWrapper,
GlobalStateWrapper,
LogWrapper,
- ObservationGlobalState,
RwareWrapper,
)
@@ -99,7 +99,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
- observation_done: RnnObservation,
+ observation_done: RnnGlobalObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -138,7 +138,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
- observation_done: RnnObservation,
+ observation_done: RnnGlobalObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
@@ -229,12 +229,7 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done = jax.tree_util.tree_map(
- lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
- config["arch"]["num_envs"], -1
- ),
- timestep.last(),
- )
+ done = 1 - timestep.discount
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
@@ -590,7 +585,7 @@ def learner_setup(
init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs)
# Select only a single agent
- init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool)
+ init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float)
init_obs_single = ObservationGlobalState(
agents_view=init_obs.agents_view[:, :, 0, :],
action_mask=init_obs.action_mask[:, :, 0, :],
@@ -679,7 +674,7 @@ def learner_setup(
config["arch"]["num_envs"],
config["system"]["num_agents"],
),
- dtype=bool,
+ dtype=float,
)
hstates = HiddenStates(policy_hstates, critic_hstates)
params = Params(actor_params, critic_params)
diff --git a/mava/types.py b/mava/types.py
index 7455e9b2f..ecb523631 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -71,7 +71,8 @@ class LogEnvState:
episode_length_info: chex.Numeric
-RnnObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
+RnnObservation: TypeAlias = Tuple[Observation, Done]
+RnnGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
class PPOTransition(NamedTuple):
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index aec4c0ed0..a0cb23903 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -150,11 +150,10 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
action_mask=timestep.observation.action_mask,
step_count=jnp.repeat(timestep.observation.step_count, n_agents),
)
- reward = jnp.repeat(timestep.reward, n_agents)
- # discount = jnp.repeat(timestep.discount, n_agents)
- # -> we won't need this if we'll use the timestep.last() for the 'done'
- # variable during training.
- return timestep.replace(observation=observation, reward=reward)
+ shared_reward = jnp.sum(timestep.reward)
+ reward = jnp.repeat(shared_reward, n_agents)
+ discount = jnp.repeat(timestep.discount, n_agents)
+ return timestep.replace(observation=observation, reward=reward, discount=discount)
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment. Updates the step count."""
From a6df77e0785b38067b85f26e6097865de3152c53 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Thu, 23 Nov 2023 17:38:23 +0100
Subject: [PATCH 06/10] revert: revert the changes made to the 'done' variable
---
mava/systems/ff_ippo_rware.py | 7 ++++++-
mava/systems/ff_mappo_rware.py | 7 ++++++-
mava/systems/rec_ippo_rware.py | 11 ++++++++---
mava/systems/rec_mappo_rware.py | 11 ++++++++---
4 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index 6902897dc..eae1bee31 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -152,7 +152,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
- done = 1 - timestep.discount
+ done = jax.tree_util.tree_map(
+ lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
+ config["arch"]["num_envs"], -1
+ ),
+ timestep.last(),
+ )
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index 131ee12af..29e92f5e0 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -157,7 +157,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# LOG EPISODE METRICS
- done = 1 - timestep.discount
+ done = jax.tree_util.tree_map(
+ lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
+ config["arch"]["num_envs"], -1
+ ),
+ timestep.last(),
+ )
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index 9dc280e1f..f151d2cba 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -229,7 +229,12 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done = 1 - timestep.discount
+ done = jax.tree_util.tree_map(
+ lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
+ config["arch"]["num_envs"], -1
+ ),
+ timestep.last(),
+ )
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
@@ -581,7 +586,7 @@ def learner_setup(
init_obs,
)
init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs)
- init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float)
+ init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool)
init_x = (init_obs, init_done)
# Initialise hidden states.
@@ -658,7 +663,7 @@ def learner_setup(
config["arch"]["num_envs"],
config["system"]["num_agents"],
),
- dtype=float,
+ dtype=bool,
)
hstates = HiddenStates(policy_hstates, critic_hstates)
params = Params(actor_params, critic_params)
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index fae0c0190..7cfa727e6 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -229,7 +229,12 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
# log episode return and length
- done = 1 - timestep.discount
+ done = jax.tree_util.tree_map(
+ lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
+ config["arch"]["num_envs"], -1
+ ),
+ timestep.last(),
+ )
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
@@ -585,7 +590,7 @@ def learner_setup(
init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs)
# Select only a single agent
- init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=float)
+ init_done = jnp.zeros((1, config["arch"]["num_envs"]), dtype=bool)
init_obs_single = ObservationGlobalState(
agents_view=init_obs.agents_view[:, :, 0, :],
action_mask=init_obs.action_mask[:, :, 0, :],
@@ -674,7 +679,7 @@ def learner_setup(
config["arch"]["num_envs"],
config["system"]["num_agents"],
),
- dtype=float,
+ dtype=bool,
)
hstates = HiddenStates(policy_hstates, critic_hstates)
params = Params(actor_params, critic_params)
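The restored done computation broadcasts one terminal flag per vectorised environment to one flag per agent. A standalone shape sketch of the repeat/reshape pipeline (sizes illustrative):

import jax
import jax.numpy as jnp

num_envs, num_agents = 3, 2
last = jnp.array([True, False, True])  # timestep.last(): shape (num_envs,)

done = jax.tree_util.tree_map(
    lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1),
    last,
)
# Each env's flag is copied to all of its agents: shape (num_envs, num_agents).
assert done.shape == (num_envs, num_agents)
assert bool(done[0, 0]) and not bool(done[1, 0])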
From 7547931e0bf819ef99377620403cb52c7467d2bc Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Fri, 24 Nov 2023 11:06:38 +0100
Subject: [PATCH 07/10] revert: revert some changes + doc changes based on
 review
---
mava/configs/system/ff_ippo.yaml | 2 +-
mava/types.py | 20 +++++++++++---------
mava/wrappers/jumanji.py | 3 +--
3 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/mava/configs/system/ff_ippo.yaml b/mava/configs/system/ff_ippo.yaml
index 9b62c7cfe..ab91c66b2 100644
--- a/mava/configs/system/ff_ippo.yaml
+++ b/mava/configs/system/ff_ippo.yaml
@@ -1,6 +1,6 @@
# --- Defaults FF-IPPO ---
-num_updates: 10000 # Number of updates
+num_updates: 1000 # Number of updates
seed: 42
# --- Agent observations ---
diff --git a/mava/types.py b/mava/types.py
index ecb523631..ffe8ecb51 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -36,21 +36,23 @@
class Observation(NamedTuple):
+ """The observation that the agent sees.
+ agents_view: the agents' view of other agents and shelves within their
+ sensor range. The number of features in the observation array
+ depends on the sensor range of the agent.
+ action_mask: boolean array specifying, for each agent, which action is legal.
+ step_count: the number of steps elapsed since the beginning of the episode.
+ """
+
agents_view: chex.Array
action_mask: chex.Array
step_count: chex.Numeric
class ObservationGlobalState(NamedTuple):
- """The observation that the agent sees.
- agents_view: the agents' view of other agents and shelves within their
- sensor range. The number of features in the observation array
- depends on the sensor range of the agent.
- action_mask: boolean array specifying, for each agent, which action
- (up, right, down, left) is legal.
- global_state: the global state of the environment, which is the
- concatenation of the agents' views.
- step_count: the number of steps elapsed since the beginning of the episode.
+ """The observation seen by agents in centralized systems.
+ Extends `Observation` by adding a `global_state` attribute for centralized training.
+ global_state: The global state of the environment, often a concatenation of agents' views.
"""
agents_view: chex.Array # (num_agents, num_obs_features)
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index a0cb23903..fc777e3b1 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -150,8 +150,7 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
action_mask=timestep.observation.action_mask,
step_count=jnp.repeat(timestep.observation.step_count, n_agents),
)
- shared_reward = jnp.sum(timestep.reward)
- reward = jnp.repeat(shared_reward, n_agents)
+ reward = jnp.repeat(timestep.reward, n_agents)
discount = jnp.repeat(timestep.discount, n_agents)
return timestep.replace(observation=observation, reward=reward, discount=discount)
From 7e500aaa2c657ce818689c589020f39284391da4 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Mon, 27 Nov 2023 15:37:36 +0100
Subject: [PATCH 08/10] fix: small edits
---
mava/wrappers/jumanji.py | 86 ++++++++++++++++++++--------------------
1 file changed, 44 insertions(+), 42 deletions(-)
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index fc777e3b1..48bac789b 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -60,6 +60,46 @@ def step(
return state, timestep
+class RwareWrapper(Wrapper):
+ """Multi-agent wrapper for the Robotic Warehouse environment."""
+
+ def __init__(self, env: RobotWarehouse):
+ super().__init__(env)
+ self._env: RobotWarehouse
+
+ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ n_agents = self._env.num_agents
+ observation = Observation(
+ agents_view=timestep.observation.agents_view,
+ action_mask=timestep.observation.action_mask,
+ step_count=jnp.repeat(timestep.observation.step_count, n_agents),
+ )
+ reward = jnp.repeat(timestep.reward, n_agents)
+ discount = jnp.repeat(timestep.discount, n_agents)
+ return timestep.replace(observation=observation, reward=reward, discount=discount)
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+ """Reset the environment. Updates the step count."""
+ state, timestep = self._env.reset(key)
+ return state, self.modify_timestep(timestep)
+
+ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
+ """Step the environment. Updates the step count."""
+ state, timestep = self._env.step(state, action)
+ return state, self.modify_timestep(timestep)
+
+ def observation_spec(self) -> specs.Spec[Observation]:
+ """Specification of the observation of the `RobotWarehouse` environment."""
+ step_count = specs.BoundedArray(
+ (self._env.num_agents,),
+ jnp.int32,
+ [0] * self._env.num_agents,
+ [self._env.time_limit] * self._env.num_agents,
+ "step_count",
+ )
+ return self._env.observation_spec().replace(step_count=step_count)
+
+
class AgentIDWrapper(Wrapper):
"""Add onehot agent IDs to observation."""
@@ -112,7 +152,9 @@ def step(
return state, timestep
- def observation_spec(self) -> specs.Spec[Observation]:
+ def observation_spec(
+ self,
+ ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]:
"""Specification of the observation of the `RobotWarehouse` environment."""
agents_view = specs.Array(
(self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view"
@@ -136,46 +178,6 @@ def observation_spec(self) -> specs.Spec[Observation]:
)
-class RwareWrapper(Wrapper):
- """Multi-agent wrapper for the Robotic Warehouse environment."""
-
- def __init__(self, env: RobotWarehouse):
- super().__init__(env)
- self._env: RobotWarehouse
-
- def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
- n_agents = self._env.num_agents
- observation = Observation(
- agents_view=timestep.observation.agents_view,
- action_mask=timestep.observation.action_mask,
- step_count=jnp.repeat(timestep.observation.step_count, n_agents),
- )
- reward = jnp.repeat(timestep.reward, n_agents)
- discount = jnp.repeat(timestep.discount, n_agents)
- return timestep.replace(observation=observation, reward=reward, discount=discount)
-
- def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
- """Reset the environment. Updates the step count."""
- state, timestep = self._env.reset(key)
- return state, self.modify_timestep(timestep)
-
- def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
- """Step the environment. Updates the step count."""
- state, timestep = self._env.step(state, action)
- return state, self.modify_timestep(timestep)
-
- def observation_spec(self) -> specs.Spec[Observation]:
- """Specification of the observation of the `RobotWarehouse` environment."""
- step_count = specs.BoundedArray(
- (self._env.num_agents,),
- jnp.int32,
- [0] * self._env.num_agents,
- [self._env.time_limit] * self._env.num_agents,
- "step_count",
- )
- return self._env.observation_spec().replace(step_count=step_count)
-
-
class GlobalStateWrapper(Wrapper):
"""Wrapper for adding global state to an environment that follows the mava API.
@@ -184,7 +186,7 @@ class GlobalStateWrapper(Wrapper):
by concatenating the observations of all agents.
"""
- def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
+ def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]:
global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
global_state = jnp.tile(global_state, (self._env.num_agents, 1))
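For reference, the global state built in modify_timestep flattens every agent's view into a single vector and hands each agent an identical copy. A standalone shape sketch mirroring the two lines above (sizes illustrative):

import jax.numpy as jnp

num_agents, num_obs_features = 2, 3
agents_view = jnp.arange(num_agents * num_obs_features, dtype=jnp.int32).reshape(
    num_agents, num_obs_features
)

# Concatenating a 2D array along axis 0 joins its rows into one flat vector.
global_state = jnp.concatenate(agents_view, axis=0)
# Tile one copy of that flat vector per agent.
global_state = jnp.tile(global_state, (num_agents, 1))

assert global_state.shape == (num_agents, num_agents * num_obs_features)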
From e6b51eb86752ea14afb9a3854917fe786ed54ca9 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Wed, 29 Nov 2023 12:50:29 +0100
Subject: [PATCH 09/10] docs: small edits based on review
---
mava/systems/rec_ippo_rware.py | 6 +++---
mava/systems/rec_mappo_rware.py | 6 +++---
mava/types.py | 26 +++++++++++++-------------
3 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index 224d61337..9678b3715 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -48,7 +48,7 @@
RecActorApply,
RecCriticApply,
RNNLearnerState,
- RnnObservation,
+ RNNObservation,
)
from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
@@ -90,7 +90,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
- observation_done: RnnObservation,
+ observation_done: RNNObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -129,7 +129,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
- observation_done: RnnObservation,
+ observation_done: RNNObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index abcac0486..9765bc1b8 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -48,7 +48,7 @@
PPOTransition,
RecActorApply,
RecCriticApply,
- RnnGlobalObservation,
+ RNNGlobalObservation,
RNNLearnerState,
)
from mava.wrappers.jumanji import (
@@ -96,7 +96,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
- observation_done: RnnGlobalObservation,
+ observation_done: RNNGlobalObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -135,7 +135,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
- observation_done: RnnGlobalObservation,
+ observation_done: RNNGlobalObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
diff --git a/mava/types.py b/mava/types.py
index ffe8ecb51..1d47e0958 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -37,27 +37,27 @@
class Observation(NamedTuple):
"""The observation that the agent sees.
- agents_view: the agents' view of other agents and shelves within their
- sensor range. The number of features in the observation array
- depends on the sensor range of the agent.
+ agents_view: the agents' view of other agents and items within their
+ field of view (fov). The number of features in the observation array
+ depends on the number of elements that can be seen in the fov of the agent.
action_mask: boolean array specifying, for each agent, which action is legal.
step_count: the number of steps elapsed since the beginning of the episode.
"""
- agents_view: chex.Array
- action_mask: chex.Array
- step_count: chex.Numeric
+ agents_view: chex.Array # (num_agents, num_obs_features)
+ action_mask: chex.Array # (num_agents, num_actions)
+ step_count: chex.Array # (num_agents, )
class ObservationGlobalState(NamedTuple):
- """The observation seen by agents in centralized systems.
- Extends `Observation` by adding a `global_state` attribute for centralized training.
+ """The observation seen by agents in centralised systems.
+ Extends `Observation` by adding a `global_state` attribute for centralised training.
global_state: The global state of the environment, often a concatenation of agents' views.
"""
agents_view: chex.Array # (num_agents, num_obs_features)
action_mask: chex.Array # (num_agents, num_actions)
- global_state: chex.Array # (num_agents, num_agents * num_obs_features, )
+ global_state: chex.Array # (num_agents, num_agents * num_obs_features)
step_count: chex.Array # (num_agents, )
@@ -73,8 +73,8 @@ class LogEnvState:
episode_length_info: chex.Numeric
-RnnObservation: TypeAlias = Tuple[Observation, Done]
-RnnGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
+RNNObservation: TypeAlias = Tuple[Observation, Done]
+RNNGlobalObservation: TypeAlias = Tuple[ObservationGlobalState, Done]
class PPOTransition(NamedTuple):
@@ -176,6 +176,6 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]):
ActorApply = Callable[[FrozenDict, Observation], Distribution]
CriticApply = Callable[[FrozenDict, Observation], Value]
RecActorApply = Callable[
- [FrozenDict, HiddenState, RnnObservation], Tuple[HiddenState, Distribution]
+ [FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Distribution]
]
-RecCriticApply = Callable[[FrozenDict, HiddenState, RnnObservation], Tuple[HiddenState, Value]]
+RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]]
From b3ed0e5f83c334e6e7d9b5a62e2c8b119555d861 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Wed, 29 Nov 2023 17:26:02 +0100
Subject: [PATCH 10/10] feat(shared.py): move the shared wrappers into a separate
 file
---
mava/systems/ff_ippo_rware.py | 3 +-
mava/systems/ff_mappo_rware.py | 8 +-
mava/systems/rec_ippo_rware.py | 3 +-
mava/systems/rec_mappo_rware.py | 8 +-
mava/types.py | 4 +-
mava/wrappers/jumanji.py | 171 +----------------------------
mava/wrappers/shared.py | 188 ++++++++++++++++++++++++++++++++
7 files changed, 200 insertions(+), 185 deletions(-)
create mode 100644 mava/wrappers/shared.py
diff --git a/mava/systems/ff_ippo_rware.py b/mava/systems/ff_ippo_rware.py
index cf93a4bc9..875c93f5b 100644
--- a/mava/systems/ff_ippo_rware.py
+++ b/mava/systems/ff_ippo_rware.py
@@ -49,7 +49,8 @@
PPOTransition,
)
from mava.utils.jax import merge_leading_dims
-from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, LogWrapper
class Actor(nn.Module):
diff --git a/mava/systems/ff_mappo_rware.py b/mava/systems/ff_mappo_rware.py
index fd6845271..e06aca90e 100644
--- a/mava/systems/ff_mappo_rware.py
+++ b/mava/systems/ff_mappo_rware.py
@@ -49,12 +49,8 @@
PPOTransition,
)
from mava.utils.jax import merge_leading_dims
-from mava.wrappers.jumanji import (
- AgentIDWrapper,
- GlobalStateWrapper,
- LogWrapper,
- RwareWrapper,
-)
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper
class Actor(nn.Module):
diff --git a/mava/systems/rec_ippo_rware.py b/mava/systems/rec_ippo_rware.py
index 9678b3715..4c1ee3636 100644
--- a/mava/systems/rec_ippo_rware.py
+++ b/mava/systems/rec_ippo_rware.py
@@ -50,7 +50,8 @@
RNNLearnerState,
RNNObservation,
)
-from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareWrapper
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, LogWrapper
class ScannedRNN(nn.Module):
diff --git a/mava/systems/rec_mappo_rware.py b/mava/systems/rec_mappo_rware.py
index 9765bc1b8..8aa49d716 100644
--- a/mava/systems/rec_mappo_rware.py
+++ b/mava/systems/rec_mappo_rware.py
@@ -51,12 +51,8 @@
RNNGlobalObservation,
RNNLearnerState,
)
-from mava.wrappers.jumanji import (
- AgentIDWrapper,
- GlobalStateWrapper,
- LogWrapper,
- RwareWrapper,
-)
+from mava.wrappers.jumanji import RwareWrapper
+from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper
class ScannedRNN(nn.Module):
diff --git a/mava/types.py b/mava/types.py
index 1d47e0958..cf5ead1d8 100644
--- a/mava/types.py
+++ b/mava/types.py
@@ -37,9 +37,7 @@
class Observation(NamedTuple):
"""The observation that the agent sees.
- agents_view: the agents' view of other agents and items within their
- field of view (fov). The number of features in the observation array
- depends on the number of elements that can be seen in the agent's fov.
+ agents_view: the agent's view of the environment.
action_mask: boolean array specifying, for each agent, which action is legal.
step_count: the number of steps elapsed since the beginning of the episode.
"""
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 48bac789b..e757a6f33 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -12,52 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Union
+from typing import Tuple
import chex
import jax.numpy as jnp
from jumanji import specs
-from jumanji.env import Environment
-from jumanji.environments.routing.robot_warehouse import RobotWarehouse, State
+from jumanji.environments.routing.robot_warehouse import RobotWarehouse
from jumanji.types import TimeStep
from jumanji.wrappers import Wrapper
-from mava.types import LogEnvState, Observation, ObservationGlobalState
-
-
-class LogWrapper(Wrapper):
- """Log the episode returns and lengths."""
-
- def reset(self, key: chex.PRNGKey) -> Tuple[LogEnvState, TimeStep]:
- """Reset the environment."""
- state, timestep = self._env.reset(key)
- state = LogEnvState(state, jnp.float32(0.0), 0, jnp.float32(0.0), 0)
- return state, timestep
-
- def step(
- self,
- state: LogEnvState,
- action: chex.Array,
- ) -> Tuple[LogEnvState, TimeStep]:
- """Step the environment."""
- env_state, timestep = self._env.step(state.env_state, action)
-
- done = timestep.last()
- not_done = 1 - done
-
- new_episode_return = state.episode_returns + jnp.mean(timestep.reward)
- new_episode_length = state.episode_lengths + 1
- episode_return_info = state.episode_return_info * not_done + new_episode_return * done
- episode_length_info = state.episode_length_info * not_done + new_episode_length * done
-
- state = LogEnvState(
- env_state=env_state,
- episode_returns=new_episode_return * not_done,
- episode_lengths=new_episode_length * not_done,
- episode_return_info=episode_return_info,
- episode_length_info=episode_length_info,
- )
- return state, timestep
+from mava.types import Observation, State
class RwareWrapper(Wrapper):
@@ -98,132 +62,3 @@ def observation_spec(self) -> specs.Spec[Observation]:
"step_count",
)
return self._env.observation_spec().replace(step_count=step_count)
-
-
-class AgentIDWrapper(Wrapper):
- """Add onehot agent IDs to observation."""
-
- def __init__(self, env: Environment, has_global_state: bool = False):
- super().__init__(env)
- self.num_obs_features = self._env.num_obs_features + self._env.num_agents
- self.has_global_state = has_global_state
-
- def _add_agent_ids(
- self, timestep: TimeStep, num_agents: int
- ) -> Union[Observation, ObservationGlobalState]:
- agent_ids = jnp.eye(num_agents)
- new_agents_view = jnp.concatenate([agent_ids, timestep.observation.agents_view], axis=-1)
-
- if self.has_global_state:
- # Add the agent IDs to the global state
- new_global_state = jnp.concatenate(
- [agent_ids, timestep.observation.global_state], axis=-1
- )
-
- return ObservationGlobalState(
- agents_view=new_agents_view,
- action_mask=timestep.observation.action_mask,
- step_count=timestep.observation.step_count,
- global_state=new_global_state,
- )
-
- else:
- return Observation(
- agents_view=new_agents_view,
- action_mask=timestep.observation.action_mask,
- step_count=timestep.observation.step_count,
- )
-
- def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
- """Reset the environment."""
- state, timestep = self._env.reset(key)
- timestep.observation = self._add_agent_ids(timestep, self._env.num_agents)
-
- return state, timestep
-
- def step(
- self,
- state: State,
- action: chex.Array,
- ) -> Tuple[State, TimeStep]:
- """Step the environment."""
- state, timestep = self._env.step(state, action)
- timestep.observation = self._add_agent_ids(timestep, self._env.num_agents)
-
- return state, timestep
-
- def observation_spec(
- self,
- ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]:
- """Specification of the observation of the `RobotWarehouse` environment."""
- agents_view = specs.Array(
- (self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view"
- )
- global_state = specs.Array(
- (
- self._env.num_agents,
- self._env.num_obs_features * self._env.num_agents + self._env.num_agents,
- ),
- jnp.int32,
- "global_state",
- )
-
- if self.has_global_state:
- return self._env.observation_spec().replace(
- agents_view=agents_view,
- global_state=global_state,
- )
- return self._env.observation_spec().replace(
- agents_view=agents_view,
- )
-
-
-class GlobalStateWrapper(Wrapper):
- """Wrapper for adding global state to an environment that follows the mava API.
-
- The wrapper includes a global environment state to be used by the centralised critic.
- Note here that since most environments do not have a global state, we create one
- by concatenating the observations of all agents.
- """
-
- def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]:
- global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
- global_state = jnp.tile(global_state, (self._env.num_agents, 1))
-
- observation = ObservationGlobalState(
- global_state=global_state,
- agents_view=timestep.observation.agents_view,
- action_mask=timestep.observation.action_mask,
- step_count=timestep.observation.step_count,
- )
-
- return timestep.replace(observation=observation)
-
- def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
- """Reset the environment. Updates the step count."""
- state, timestep = self._env.reset(key)
- return state, self.modify_timestep(timestep)
-
- def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
- """Step the environment. Updates the step count."""
- state, timestep = self._env.step(state, action)
- return state, self.modify_timestep(timestep)
-
- def observation_spec(self) -> specs.Spec[ObservationGlobalState]:
- """Specification of the observation of the `RobotWarehouse` environment."""
-
- obs_spec = self._env.observation_spec()
- global_state = specs.Array(
- (self._env.num_agents, self._env.num_agents * self._env.num_obs_features),
- jnp.int32,
- "global_state",
- )
-
- return specs.Spec(
- ObservationGlobalState,
- "ObservationSpec",
- agents_view=obs_spec.agents_view,
- action_mask=obs_spec.action_mask,
- global_state=global_state,
- step_count=obs_spec.step_count,
- )
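With the shared wrappers split out into mava/wrappers/shared.py (added below), environment setup composes the two modules. A minimal sketch, assuming the "RobotWarehouse-v0" registration name:

import jumanji

from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper

env = jumanji.make("RobotWarehouse-v0")           # registration name assumed
env = RwareWrapper(env)                           # env-specific preprocessing
env = GlobalStateWrapper(env)                     # shared: global state for the critic
env = AgentIDWrapper(env, has_global_state=True)  # shared: one-hot agent IDs
env = LogWrapper(env)                             # shared: episode returns and lengths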
diff --git a/mava/wrappers/shared.py b/mava/wrappers/shared.py
new file mode 100644
index 000000000..459a169e3
--- /dev/null
+++ b/mava/wrappers/shared.py
@@ -0,0 +1,188 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Union
+
+import chex
+import jax.numpy as jnp
+from jumanji import specs
+from jumanji.env import Environment
+from jumanji.types import TimeStep
+from jumanji.wrappers import Wrapper
+
+from mava.types import LogEnvState, Observation, ObservationGlobalState, State
+
+
+class LogWrapper(Wrapper):
+ """Log the episode returns and lengths."""
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[LogEnvState, TimeStep]:
+ """Reset the environment."""
+ state, timestep = self._env.reset(key)
+ state = LogEnvState(state, jnp.float32(0.0), 0, jnp.float32(0.0), 0)
+ return state, timestep
+
+ def step(
+ self,
+ state: LogEnvState,
+ action: chex.Array,
+ ) -> Tuple[LogEnvState, TimeStep]:
+ """Step the environment."""
+ env_state, timestep = self._env.step(state.env_state, action)
+
+ done = timestep.last()
+ not_done = 1 - done
+
+ new_episode_return = state.episode_returns + jnp.mean(timestep.reward)
+ new_episode_length = state.episode_lengths + 1
+ episode_return_info = state.episode_return_info * not_done + new_episode_return * done
+ episode_length_info = state.episode_length_info * not_done + new_episode_length * done
+
+ state = LogEnvState(
+ env_state=env_state,
+ episode_returns=new_episode_return * not_done,
+ episode_lengths=new_episode_length * not_done,
+ episode_return_info=episode_return_info,
+ episode_length_info=episode_length_info,
+ )
+ return state, timestep
+
+
+class AgentIDWrapper(Wrapper):
+ """Add onehot agent IDs to observation."""
+
+ def __init__(self, env: Environment, has_global_state: bool = False):
+ super().__init__(env)
+ self.num_obs_features = self._env.num_obs_features + self._env.num_agents
+ self.has_global_state = has_global_state
+
+ def _add_agent_ids(
+ self, timestep: TimeStep, num_agents: int
+ ) -> Union[Observation, ObservationGlobalState]:
+ agent_ids = jnp.eye(num_agents)
+ new_agents_view = jnp.concatenate([agent_ids, timestep.observation.agents_view], axis=-1)
+
+ if self.has_global_state:
+ # Add the agent IDs to the global state
+ new_global_state = jnp.concatenate(
+ [agent_ids, timestep.observation.global_state], axis=-1
+ )
+
+ return ObservationGlobalState(
+ agents_view=new_agents_view,
+ action_mask=timestep.observation.action_mask,
+ step_count=timestep.observation.step_count,
+ global_state=new_global_state,
+ )
+
+ else:
+ return Observation(
+ agents_view=new_agents_view,
+ action_mask=timestep.observation.action_mask,
+ step_count=timestep.observation.step_count,
+ )
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+ """Reset the environment."""
+ state, timestep = self._env.reset(key)
+ timestep.observation = self._add_agent_ids(timestep, self._env.num_agents)
+
+ return state, timestep
+
+ def step(
+ self,
+ state: State,
+ action: chex.Array,
+ ) -> Tuple[State, TimeStep]:
+ """Step the environment."""
+ state, timestep = self._env.step(state, action)
+ timestep.observation = self._add_agent_ids(timestep, self._env.num_agents)
+
+ return state, timestep
+
+ def observation_spec(
+ self,
+ ) -> Union[specs.Spec[Observation], specs.Spec[ObservationGlobalState]]:
+ """Specification of the observation of the `RobotWarehouse` environment."""
+ agents_view = specs.Array(
+ (self._env.num_agents, self.num_obs_features), jnp.int32, "agents_view"
+ )
+ global_state = specs.Array(
+ (
+ self._env.num_agents,
+ self._env.num_obs_features * self._env.num_agents + self._env.num_agents,
+ ),
+ jnp.int32,
+ "global_state",
+ )
+
+ if self.has_global_state:
+ return self._env.observation_spec().replace(
+ agents_view=agents_view,
+ global_state=global_state,
+ )
+ return self._env.observation_spec().replace(
+ agents_view=agents_view,
+ )
+
+
+class GlobalStateWrapper(Wrapper):
+ """Wrapper for adding global state to an environment that follows the mava API.
+
+ The wrapper includes a global environment state to be used by the centralised critic.
+ Note here that since most environments do not have a global state, we create one
+ by concatenating the observations of all agents.
+ """
+
+ def modify_timestep(self, timestep: TimeStep) -> TimeStep[ObservationGlobalState]:
+ global_state = jnp.concatenate(timestep.observation.agents_view, axis=0)
+ global_state = jnp.tile(global_state, (self._env.num_agents, 1))
+
+ observation = ObservationGlobalState(
+ global_state=global_state,
+ agents_view=timestep.observation.agents_view,
+ action_mask=timestep.observation.action_mask,
+ step_count=timestep.observation.step_count,
+ )
+
+ return timestep.replace(observation=observation)
+
+ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+ """Reset the environment. Updates the step count."""
+ state, timestep = self._env.reset(key)
+ return state, self.modify_timestep(timestep)
+
+ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
+ """Step the environment. Updates the step count."""
+ state, timestep = self._env.step(state, action)
+ return state, self.modify_timestep(timestep)
+
+ def observation_spec(self) -> specs.Spec[ObservationGlobalState]:
+ """Specification of the observation of the `RobotWarehouse` environment."""
+
+ obs_spec = self._env.observation_spec()
+ global_state = specs.Array(
+ (self._env.num_agents, self._env.num_agents * self._env.num_obs_features),
+ jnp.int32,
+ "global_state",
+ )
+
+ return specs.Spec(
+ ObservationGlobalState,
+ "ObservationSpec",
+ agents_view=obs_spec.agents_view,
+ action_mask=obs_spec.action_mask,
+ global_state=global_state,
+ step_count=obs_spec.step_count,
+ )
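Two details of the shared wrappers above are worth making concrete. GlobalStateWrapper.modify_timestep flattens the per-agent views into one vector and hands every agent a copy; a worked sketch with assumed sizes (3 agents, 4 observation features):

import jax.numpy as jnp

agents_view = jnp.arange(12.0).reshape(3, 4)         # (num_agents, num_obs_features)

# Concatenating a 2D array along axis 0 joins its rows: shape (12,).
global_state = jnp.concatenate(agents_view, axis=0)

# One copy per agent: shape (3, 12), matching the
# (num_agents, num_agents * num_obs_features) spec declared above.
global_state = jnp.tile(global_state, (3, 1))

Likewise, LogWrapper.step uses the done/not_done masks as a branch-free conditional: the running episode_returns and episode_lengths are zeroed when an episode terminates, while the *_info fields latch the finished episode's return and length only on that final step.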