feature(nyz): add MAPPO/MASAC task example (#661)
* feature(nyz): add MAPPO/MASAC task example

* feature(nyz): add example and polish style
PaParaZz1 authored May 12, 2023
1 parent a8f0ac9 commit 5ae99ed
Showing 10 changed files with 125 additions and 30 deletions.
45 changes: 45 additions & 0 deletions ding/example/mappo.py
@@ -0,0 +1,45 @@
import gym
from ditk import logging
from ding.model import MAVAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, online_logger, termination_checker
from ding.utils import set_pkg_seed
from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = SubprocessEnvManagerV2(
            env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = SubprocessEnvManagerV2(
            env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = MAVAC(**cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(gae_estimator(cfg, policy.collect_mode))
        task.use(multistep_trainer(policy.learn_mode, log_freq=100))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
        task.use(online_logger(train_show_freq=10))
        task.use(termination_checker(max_env_step=int(1e6)))
        task.run()


if __name__ == "__main__":
    main()
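This pipeline can be exercised locally by running the module directly; the sketch below is only illustrative and shrinks the imported config before calling main(). The field names come from ptz_simple_spread_mappo_config, while the particular values and the CPU choice are assumptions, not part of this commit.

# Illustrative smoke-test wrapper (not part of this commit): shrink the config
# that mappo.py imports, then reuse its main() pipeline unchanged.
from ding.example.mappo import main, main_config

main_config.policy.cuda = False          # assumed: run on CPU for a quick check
main_config.env.collector_env_num = 2    # assumed: fewer parallel collector envs
main_config.env.evaluator_env_num = 2    # assumed: fewer parallel evaluator envs

if __name__ == "__main__":
    main()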
49 changes: 49 additions & 0 deletions ding/example/masac.py
@@ -0,0 +1,49 @@
import gym
from ditk import logging
from ding.model import MAQAC
from ding.policy import SACDiscretePolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, CkptSaver, \
data_pusher, online_logger, termination_checker, eps_greedy_handler
from ding.utils import set_pkg_seed
from dizoo.petting_zoo.config.ptz_simple_spread_masac_config import main_config, create_config
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = MAQAC(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = SACDiscretePolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=100))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
        task.use(online_logger(train_show_freq=10))
        task.use(termination_checker(max_env_step=int(1e6)))
        task.run()


if __name__ == "__main__":
    main()
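The random_collect_size=5000 argument above means the first 5000 environment steps are collected with random actions before the SAC policy's own actions are used, which warms up the replay buffer. A plain-Python sketch of that switch (illustrative only, not DI-engine's StepCollector implementation):

import random

# Illustrative warmup rule only: act randomly until `random_collect_size` env
# steps have been gathered, then defer to the learned policy's action.
def choose_action(env_step: int, random_collect_size: int, policy_action: int, n_actions: int = 5) -> int:
    if env_step < random_collect_size:
        return random.randrange(n_actions)  # uniform random action during warmup
    return policy_action                    # afterwards, use the policy's action

print(choose_action(100, 5000, policy_action=3))    # random warmup action
print(choose_action(10000, 5000, policy_action=3))  # 3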
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/logger.py
@@ -73,7 +73,7 @@ def _logger(ctx: "OnlineRLContext"):
else:
output = ctx.train_output
for k, v in output.items():
if k in ['priority']:
if k in ['priority', 'td_error_priority']:
continue
if "[scalars]" in k:
new_k = k.split(']')[-1]
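The new 'td_error_priority' entry extends the logger's skip list: per-sample priority lists are not scalars, so they never reach the scalar logging path. A plain-Python sketch of the filter, with illustrative values:

# Keys mirror the skip list in the change above; the numbers are made up.
train_output = {'total_loss': 0.31, 'priority': [0.8, 1.2], 'td_error_priority': [0.8, 1.2]}
scalars = {k: v for k, v in train_output.items() if k not in ('priority', 'td_error_priority')}
print(scalars)  # {'total_loss': 0.31}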
9 changes: 5 additions & 4 deletions ding/framework/middleware/functional/trainer.py
@@ -7,13 +7,14 @@
from ding.framework import task, OfflineRLContext, OnlineRLContext


def trainer(cfg: EasyDict, policy: Policy) -> Callable:
def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
"""
Overview:
The middleware that executes a single training process.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
- log_freq (:obj:`int`): The frequency (iteration) of showing log.
"""

def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
@@ -31,7 +32,7 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
return
data = ctx.train_data
train_output = policy.forward(ctx.train_data)
if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0:
if ctx.train_iter % log_freq == 0:
if isinstance(ctx, OnlineRLContext):
logging.info(
'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
@@ -50,13 +51,13 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
return _train


def multistep_trainer(policy: Policy, log_freq: int) -> Callable:
def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
"""
Overview:
The middleware that executes training for a target num of steps.
Arguments:
- policy (:obj:`Policy`): The policy specialized for multi-step training.
- int (:obj:`int`): The frequency (iteration) of showing log.
- log_freq (:obj:`int`): The frequency (iteration) of showing log.
"""
last_log_iter = -1

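With log_freq exposed as an argument (default 100) instead of the cfg.policy.learn.learner.hook.log_show_after_iter lookup, log gating reduces to a modulo check on the training iteration; a minimal sketch with illustrative values:

# Sketch of the gating used above: with the default log_freq=100, metrics are
# shown at iterations 0, 100, 200, ...
log_freq = 100
for train_iter in range(250):
    if train_iter % log_freq == 0:
        print('would log training metrics at iter {}'.format(train_iter))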
12 changes: 8 additions & 4 deletions ding/framework/middleware/learner.py
@@ -8,6 +8,8 @@

if TYPE_CHECKING:
from ding.framework import Context, OnlineRLContext
from ding.policy import Policy
from ding.reward_model import BaseRewardModel


class OffPolicyLearner:
@@ -25,17 +27,19 @@ def __new__(cls, *args, **kwargs):
def __init__(
self,
cfg: EasyDict,
policy,
policy: 'Policy',
buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
reward_model=None
reward_model: Optional['BaseRewardModel'] = None,
log_freq: int = 100,
) -> None:
"""
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be trained.
- buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
- reward_model (:obj:`nn.Module`): Additional reward estimator likes RND, ICM, etc. \
- buffer (:obj:`Buffer`): The replay buffer to store the data for training.
- reward_model (:obj:`BaseRewardModel`): Additional reward estimator likes RND, ICM, etc. \
default to None.
- log_freq (:obj:`int`): The frequency (iteration) of showing log.
"""
self.cfg = cfg
self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
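Given the widened constructor above, callers can tune the learner's log frequency without touching the config; a usage fragment only (cfg, policy and buffer_ are the objects built in ding/example/masac.py earlier in this commit, and log_freq=50 is an assumed, illustrative value):

# Usage fragment, not a standalone script: names come from the masac.py example.
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=50))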
3 changes: 3 additions & 0 deletions ding/policy/sac.py
@@ -317,6 +317,9 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
# target update
self._target_model.update(self._learn_model.state_dict())
return {
'total_loss': loss_dict['total_loss'].item(),
'policy_loss': loss_dict['policy_loss'].item(),
'critic_loss': loss_dict['critic_loss'].item(),
'cur_lr_q': self._optimizer_q.defaults['lr'],
'cur_lr_p': self._optimizer_policy.defaults['lr'],
'priority': td_error_per_sample.abs().tolist(),
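The three new keys surface SAC's losses as plain floats; .item() converts a 0-dim tensor into a Python float, which is the scalar form the logging middleware expects (unlike the per-sample 'priority' list, which stays filtered out). A tiny check:

import torch

# Why .item(): a 0-dim tensor becomes a plain Python float, i.e. a loggable scalar.
total_loss = torch.tensor(0.314)
print(type(total_loss), type(total_loss.item()))  # <class 'torch.Tensor'> <class 'float'>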
10 changes: 6 additions & 4 deletions ding/torch_utils/network/nn_module.py
@@ -166,16 +166,18 @@ def conv2d_block(
)
)
if norm_type is not None:
if norm_type is 'LN':
if norm_type == 'LN':
# LN is implemented as GroupNorm with 1 group.
block.append(nn.GroupNorm(1, out_channels))
elif norm_type is 'GN':
elif norm_type == 'GN':
block.append(nn.GroupNorm(num_groups_for_gn, out_channels))
elif norm_type in ['BN', 'IN', 'SyncBN']:
block.append(build_normalization(norm_type, dim=2)(out_channels))
else:
raise KeyError("Invalid value in norm_type: {}. The valid norm_type are "
"BN, LN, IN, GN and SyncBN.".format(norm_type))
raise KeyError(
"Invalid value in norm_type: {}. The valid norm_type are "
"BN, LN, IN, GN and SyncBN.".format(norm_type)
)

if activation is not None:
block.append(activation)
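Two things change here: the string comparisons move from `is` to `==` (identity checks on strings only happen to work when CPython interns them, so equality is the correct test), and the 'LN' branch is documented as GroupNorm with a single group, which normalizes each sample over all of (C, H, W). A short check of that behavior, with illustrative shapes:

import torch
import torch.nn as nn

# 'LN' in conv2d_block == GroupNorm with one group over all channels; each
# sample is normalized across (C, H, W). Shapes below are illustrative only.
x = torch.randn(2, 8, 4, 4)
ln_like = nn.GroupNorm(1, 8)
y = ln_like(x)
print(y.shape)                                                    # torch.Size([2, 8, 4, 4])
print(round(y[0].mean().item(), 4), round(y[0].std().item(), 2))  # ~0.0 and ~1.0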
14 changes: 1 addition & 13 deletions dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py
@@ -24,25 +24,17 @@
cuda=True,
on_policy=False,
multi_agent=True,
# priority=True,
# priority_IS_weight=False,
random_collect_size=0,
random_collect_size=5000,
model=dict(
agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
n_landmark * 2 + n_agent * (n_agent - 1) * 2,
action_shape=5,
# SAC concerned
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=50,
batch_size=320,
# ==============================================================
# The following configs is algorithm-specific
# ==============================================================
# learning_rates
learning_rate_q=5e-4,
learning_rate_policy=5e-4,
@@ -51,16 +43,12 @@
discount_factor=0.99,
alpha=0.2,
auto_alpha=True,
log_space=True,
ignore_down=False,
target_entropy=-2,
),
collect=dict(
n_sample=1600,
unroll_len=1,
env_num=collector_env_num,
),
command=dict(),
eval=dict(
env_num=evaluator_env_num,
evaluator=dict(eval_freq=50, ),
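For reference, the observation-shape arithmetic shown in this config expands as follows; a worked example assuming the n_agent = 3 and n_landmark = 3 defaults defined earlier in the file (outside this hunk):

# Worked example of the shape formulas above. n_agent and n_landmark are the
# assumed defaults from earlier in this config file, not part of this hunk.
n_agent, n_landmark = 3, 3
agent_obs_shape = 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
global_obs_shape = agent_obs_shape + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2
print(agent_obs_shape, global_obs_shape)  # 18 48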
3 changes: 2 additions & 1 deletion dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
@@ -180,6 +180,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
obs, rew, done, trunc, info = self._env.step(action)
obs_n = self._process_obs(obs)
rew_n = np.array([sum([rew[agent] for agent in self._agents])])
rew_n = rew_n.astype(np.float32)
# collide_sum = 0
# for i in range(self._num_agents):
# collide_sum += info['n'][i][1]
@@ -271,7 +272,7 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa
1
)
# action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1.
ret['action_mask'] = np.ones((self._num_agents, *self._action_dim))
ret['action_mask'] = np.ones((self._num_agents, *self._action_dim)).astype(np.float32)
return ret

def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa
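Both casts normalize NumPy's default float64 output to float32, the dtype the rest of the pipeline's torch tensors use; a small check of why the casts are needed:

import numpy as np

# NumPy builds float64 arrays by default, hence the explicit float32 casts above.
# The reward values and the (3, 5) mask shape are illustrative.
rew_n = np.array([sum([-0.4, -0.4, -0.4])])
print(rew_n.dtype, rew_n.astype(np.float32).dtype)              # float64 float32
action_mask = np.ones((3, 5))
print(action_mask.dtype, action_mask.astype(np.float32).dtype)  # float64 float32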
8 changes: 5 additions & 3 deletions dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py
@@ -1,10 +1,10 @@
from easydict import EasyDict
import pytest
import numpy as np
import pettingzoo
from ding.utils import import_module
from easydict import EasyDict
import pytest

from dizoo.petting_zoo.envs.petting_zoo_env import PettingZooEnv
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv


@pytest.mark.envtest
@@ -39,6 +39,7 @@ def test_agent_obs_only(self):
assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2)
assert isinstance(timestep.done, bool), timestep.done
assert isinstance(timestep.reward, np.ndarray), timestep.reward
assert timestep.reward.dtype == np.float32
print(env.observation_space, env.action_space, env.reward_space)
env.close()

@@ -80,6 +81,7 @@ def test_dict_obs(self):
assert timestep.obs['agent_alone_padding_state'].shape == (
n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
)
assert timestep.obs['action_mask'].dtype == np.float32
assert isinstance(timestep.done, bool), timestep.done
assert isinstance(timestep.reward, np.ndarray), timestep.reward
print(env.observation_space, env.action_space, env.reward_space)
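The new assertions pin the float32 dtypes introduced by the env change above. A hedged way to run just these env tests through pytest's Python API, assuming execution from the repository root and that the envtest marker is registered in the repo's pytest configuration:

import pytest

# Illustrative invocation only; the path and the 'envtest' marker are assumptions
# about the repository layout, not part of this commit.
pytest.main(['-m', 'envtest', '-s', 'dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py'])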