
feat: general wrapper framework #948

Merged: 12 commits, Nov 30, 2023
2 changes: 1 addition & 1 deletion docs/jumanji_rware_comparison.md
@@ -66,7 +66,7 @@ Please see below for Mava's recurrent and feedforward implementations of IPPO an
<img src="images/rware_results/rec_mappo/small-4ag.png" alt="Mava rec mappo small 4ag" width="30%" style="display:inline-block; margin-right: 10px;"/>
</a>
<br>
<div style="text-align:center; margin-top: 10px;"> Mava recurrent IPPO performance on the <code>tiny-2ag</code>, <code>tiny-4ag</code> and <code>small-4ag</code> RWARE tasks.</div>
<div style="text-align:center; margin-top: 10px;"> Mava recurrent MAPPO performance on the <code>tiny-2ag</code>, <code>tiny-4ag</code> and <code>small-4ag</code> RWARE tasks.</div>
</p>


2 changes: 1 addition & 1 deletion mava/configs/logger/base_logger.yaml
@@ -4,7 +4,7 @@ use_tf: True # Whether to use tensorboard logging.
base_exp_path: results # Base path for logging.

# --- Neptune logging ---
use_neptune: True
use_neptune: False

kwargs:
neptune_project: Instadeep/Mava
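With Neptune logging disabled by default, an individual run can still turn it back on without editing the checked-in file. A minimal sketch using OmegaConf (already imported by the system scripts); the path is the file shown above and the override value is illustrative:

from omegaconf import OmegaConf

# Load the logger defaults and re-enable Neptune for this run only.
logger_cfg = OmegaConf.load("mava/configs/logger/base_logger.yaml")
logger_cfg.use_neptune = True  # the checked-in default is now False
print(OmegaConf.to_yaml(logger_cfg))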
6 changes: 4 additions & 2 deletions mava/evaluator.py
@@ -105,7 +105,9 @@ def evaluator_fn(trained_params: FrozenDict, rng: chex.PRNGKey) -> ExperimentOut
# Add dimension to pmap over.
step_rngs = jnp.stack(step_rngs).reshape(eval_batch, -1)

eval_state = EvalState(step_rngs, env_states, timesteps, 0, 0.0)
eval_state = EvalState(
step_rngs, env_states, timesteps, 0, jnp.zeros_like(timesteps.reward)
)
eval_metrics = jax.vmap(
eval_one_episode,
in_axes=(None, EvalState(0, 0, 0, None, None)),
@@ -231,7 +233,7 @@ def evaluator_fn(
dones=dones,
hstate=init_hstate,
step_count_=0,
return_=0.0,
return_=jnp.zeros_like(timesteps.reward),
)

eval_metrics = jax.vmap(
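Both hunks above replace a scalar 0.0 return accumulator with jnp.zeros_like(timesteps.reward), so the evaluator accumulates returns with the same shape and dtype as the per-agent reward. A toy sketch of the difference, with illustrative shapes rather than the evaluator's real internals:

import jax.numpy as jnp

num_envs, num_agents = 4, 2
reward = jnp.ones((num_envs, num_agents), dtype=jnp.float32)

# Before: a Python scalar, so the accumulator only picks up the reward's shape by broadcasting.
old_return = 0.0
old_return = old_return + reward          # shape (4, 2) only after the first step

# After: the accumulator starts out with the reward's shape and dtype explicitly.
new_return = jnp.zeros_like(reward)       # shape (4, 2), float32 from the start
new_return = new_return + reward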
15 changes: 8 additions & 7 deletions mava/systems/ff_ippo_rware.py
@@ -30,7 +30,6 @@
from flax.linen.initializers import constant, orthogonal
from jumanji.env import Environment
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
from jumanji.types import Observation
from jumanji.wrappers import AutoResetWrapper
from omegaconf import DictConfig, OmegaConf
from optax._src.base import OptState
@@ -44,12 +43,14 @@
ExperimentOutput,
LearnerFn,
LearnerState,
Observation,
OptStates,
Params,
PPOTransition,
)
from mava.utils.jax import merge_leading_dims
from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, LogWrapper


class Actor(nn.Module):
@@ -149,19 +150,19 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
done, reward = jax.tree_util.tree_map(
done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
(timestep.last(), timestep.reward),
timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}

transition = PPOTransition(
done, action, value, reward, log_prob, last_timestep.observation, info
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
learner_state = LearnerState(params, opt_states, rng, env_state, timestep)
return learner_state, transition
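In the hunk above, only the done flag is still tiled out to one entry per agent; timestep.reward goes into the transition as-is, on the assumption that the new RwareWrapper already returns a per-agent reward of shape (num_envs, num_agents). A toy illustration of the shapes involved (sizes are made up):

import jax.numpy as jnp

num_envs, num_agents = 8, 4

# timestep.last() is one flag per environment; repeat it so each agent gets a copy,
# exactly as the tree_map lambda above does.
env_done = jnp.zeros((num_envs,), dtype=bool)
done = jnp.repeat(env_done, num_agents).reshape(num_envs, -1)   # (8, 4)

# The reward needs no such treatment any more: it is assumed to already be per-agent.
reward = jnp.zeros((num_envs, num_agents))
assert done.shape == reward.shape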
@@ -505,14 +506,14 @@ def run_experiment(_config: Dict) -> None:
# Create envs
generator = RandomGenerator(**config["env"]["rware_scenario"]["task_config"])
env = jumanji.make(config["env"]["env_name"], generator=generator)
env = RwareMultiAgentWrapper(env)
env = RwareWrapper(env)
# Add agent id to observation.
if config["system"]["add_agent_id"]:
env = AgentIDWrapper(env)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env"]["env_name"], generator=generator)
eval_env = RwareMultiAgentWrapper(eval_env)
eval_env = RwareWrapper(eval_env)
if config["system"]["add_agent_id"]:
eval_env = AgentIDWrapper(eval_env)

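Pulled out of the run_experiment hunk above, the training environment is now built by composing the new wrappers in this order. A sketch mirroring the diff; the Jumanji environment id is an assumption standing in for config["env"]["env_name"]:

import jumanji
from jumanji.wrappers import AutoResetWrapper
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, LogWrapper

env = jumanji.make("RobotWarehouse-v0")  # assumed id; the script reads it from the config
env = RwareWrapper(env)                  # replaces the old RwareMultiAgentWrapper
env = AgentIDWrapper(env)                # optional: appends agent ids to observations
env = AutoResetWrapper(env)              # Jumanji auto-reset between episodes
env = LogWrapper(env)                    # exposes episode_return_info / episode_length_info

As in the diff, the evaluation environment gets the same RwareWrapper and optional AgentIDWrapper treatment, but not the AutoResetWrapper or LogWrapper.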
23 changes: 10 additions & 13 deletions mava/systems/ff_mappo_rware.py
@@ -29,7 +29,6 @@
from flax.linen.initializers import constant, orthogonal
from jumanji.env import Environment
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
from jumanji.types import Observation
from jumanji.wrappers import AutoResetWrapper
from omegaconf import DictConfig, OmegaConf
from optax._src.base import OptState
@@ -43,17 +42,15 @@
ExperimentOutput,
LearnerFn,
LearnerState,
Observation,
ObservationGlobalState,
OptStates,
Params,
PPOTransition,
)
from mava.utils.jax import merge_leading_dims
from mava.wrappers.jumanji import (
AgentIDWrapper,
LogWrapper,
ObservationGlobalState,
RwareMultiAgentWithGlobalStateWrapper,
)
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper


class Actor(nn.Module):
@@ -90,7 +87,7 @@ class Critic(nn.Module):
"""Critic Network."""

@nn.compact
def __call__(self, observation: Observation) -> chex.Array:
def __call__(self, observation: ObservationGlobalState) -> chex.Array:
"""Forward pass."""

critic_output = nn.Dense(128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(
@@ -153,19 +150,19 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
done, reward = jax.tree_util.tree_map(
done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
(timestep.last(), timestep.reward),
timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}

transition = PPOTransition(
done, action, value, reward, log_prob, last_timestep.observation, info
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
learner_state = LearnerState(params, opt_states, rng, env_state, timestep)
return learner_state, transition
@@ -514,14 +511,14 @@ def run_experiment(_config: Dict) -> None:
# Create envs
generator = RandomGenerator(**config["env"]["rware_scenario"]["task_config"])
env = jumanji.make(config["env"]["env_name"], generator=generator)
env = RwareMultiAgentWithGlobalStateWrapper(env)
env = GlobalStateWrapper(RwareWrapper(env))
# Add agent id to observation.
if config["system"]["add_agent_id"]:
env = AgentIDWrapper(env=env, has_global_state=True)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env"]["env_name"], generator=generator)
eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env)
eval_env = GlobalStateWrapper(RwareWrapper(eval_env))
if config["system"]["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)

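The only environment-setup difference from the IPPO sketch above is the global state: the base env is additionally wrapped in GlobalStateWrapper so the ObservationGlobalState-typed critic can see a shared view, and AgentIDWrapper is told about it. A sketch of just that difference, under the same assumed Jumanji id:

import jumanji
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper

raw_env = jumanji.make("RobotWarehouse-v0")            # assumed id, as above
env = GlobalStateWrapper(RwareWrapper(raw_env))        # per-agent obs plus a shared global state
env = AgentIDWrapper(env=env, has_global_state=True)   # the wrapper is told the obs carries a global state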
18 changes: 10 additions & 8 deletions mava/systems/rec_ippo_rware.py
@@ -48,8 +48,10 @@
RecActorApply,
RecCriticApply,
RNNLearnerState,
RNNObservation,
)
from mava.wrappers.jumanji import AgentIDWrapper, LogWrapper, RwareMultiAgentWrapper
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, LogWrapper


class ScannedRNN(nn.Module):
@@ -89,7 +91,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
observation_done: Tuple[chex.Array, chex.Array],
observation_done: RNNObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -128,7 +130,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
observation_done: Tuple[chex.Array, chex.Array],
observation_done: RNNObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
@@ -225,19 +227,19 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# log episode return and length
done, reward = jax.tree_util.tree_map(
done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
(timestep.last(), timestep.reward),
timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}

transition = PPOTransition(
done, action, value, reward, log_prob, last_timestep.observation, info
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
hstates = HiddenStates(policy_hidden_state, critic_hidden_state)
learner_state = RNNLearnerState(
@@ -685,14 +687,14 @@ def run_experiment(_config: Dict) -> None:
# Create envs
generator = RandomGenerator(**config["env"]["rware_scenario"]["task_config"])
env = jumanji.make(config["env"]["env_name"], generator=generator)
env = RwareMultiAgentWrapper(env)
env = RwareWrapper(env)
# Add agent id to observation.
if config["system"]["add_agent_id"]:
env = AgentIDWrapper(env)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env"]["env_name"], generator=generator)
eval_env = RwareMultiAgentWrapper(eval_env)
eval_env = RwareWrapper(eval_env)
if config["system"]["add_agent_id"]:
eval_env = AgentIDWrapper(eval_env)

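The recurrent actor and critic above now annotate their (observation, done) input as RNNObservation from mava.types rather than a bare Tuple[chex.Array, chex.Array]. Judging from how __call__ unpacks it, RNNObservation is assumed to be a pair of a per-agent observation and a done array; a toy illustration of that pair with stand-in arrays:

import jax.numpy as jnp

num_envs, num_agents, obs_dim = 8, 4, 5

# Stand-ins only: the real first element is an Observation pytree, not a bare array.
agents_view = jnp.zeros((num_envs, num_agents, obs_dim))
done = jnp.zeros((num_envs, num_agents), dtype=bool)

observation_done = (agents_view, done)   # the kind of pair an RNNObservation carries
observation, done = observation_done     # unpacked exactly as Actor/Critic.__call__ does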
24 changes: 11 additions & 13 deletions mava/systems/rec_mappo_rware.py
@@ -42,19 +42,17 @@
ExperimentOutput,
HiddenStates,
LearnerFn,
ObservationGlobalState,
OptStates,
Params,
PPOTransition,
RecActorApply,
RecCriticApply,
RNNGlobalObservation,
RNNLearnerState,
)
from mava.wrappers.jumanji import (
AgentIDWrapper,
LogWrapper,
ObservationGlobalState,
RwareMultiAgentWithGlobalStateWrapper,
)
from mava.wrappers.jumanji import RwareWrapper
from mava.wrappers.shared import AgentIDWrapper, GlobalStateWrapper, LogWrapper


class ScannedRNN(nn.Module):
@@ -94,7 +92,7 @@ class Actor(nn.Module):
def __call__(
self,
policy_hidden_state: chex.Array,
observation_done: Tuple[chex.Array, chex.Array],
observation_done: RNNGlobalObservation,
) -> Tuple[chex.Array, distrax.Categorical]:
"""Forward pass."""
observation, done = observation_done
@@ -133,7 +131,7 @@ class Critic(nn.Module):
def __call__(
self,
critic_hidden_state: Tuple[chex.Array, chex.Array],
observation_done: Tuple[chex.Array, chex.Array],
observation_done: RNNGlobalObservation,
) -> Tuple[chex.Array, chex.Array]:
"""Forward pass."""
observation, done = observation_done
@@ -224,19 +222,19 @@ def _env_step(
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# log episode return and length
done, reward = jax.tree_util.tree_map(
done = jax.tree_util.tree_map(
lambda x: jnp.repeat(x, config["system"]["num_agents"]).reshape(
config["arch"]["num_envs"], -1
),
(timestep.last(), timestep.reward),
timestep.last(),
)
info = {
"episode_return": env_state.episode_return_info,
"episode_length": env_state.episode_length_info,
}

transition = PPOTransition(
done, action, value, reward, log_prob, last_timestep.observation, info
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
hstates = HiddenStates(policy_hidden_state, critic_hidden_state)
learner_state = RNNLearnerState(
@@ -694,14 +692,14 @@ def run_experiment(_config: Dict) -> None:
# Create envs
generator = RandomGenerator(**config["env"]["rware_scenario"]["task_config"])
env = jumanji.make(config["env"]["env_name"], generator=generator)
env = RwareMultiAgentWithGlobalStateWrapper(env)
env = GlobalStateWrapper(RwareWrapper(env))
# Add agent id to observation.
if config["system"]["add_agent_id"]:
env = AgentIDWrapper(env=env, has_global_state=True)
env = AutoResetWrapper(env)
env = LogWrapper(env)
eval_env = jumanji.make(config["env"]["env_name"], generator=generator)
eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env)
eval_env = GlobalStateWrapper(RwareWrapper(eval_env))
if config["system"]["add_agent_id"]:
eval_env = AgentIDWrapper(env=eval_env, has_global_state=True)
