From 5ae99ed746acd7e072a1f071b5902181f530d148 Mon Sep 17 00:00:00 2001 From: Swain Date: Fri, 12 May 2023 20:24:43 +0800 Subject: [PATCH] feature(nyz): add MAPPO/MASAC task example (#661) * feature(nyz): add MAPPO/MASAC task example * feature(nyz): add example and polish style --- ding/example/mappo.py | 45 +++++++++++++++++ ding/example/masac.py | 49 +++++++++++++++++++ .../framework/middleware/functional/logger.py | 2 +- .../middleware/functional/trainer.py | 9 ++-- ding/framework/middleware/learner.py | 12 +++-- ding/policy/sac.py | 3 ++ ding/torch_utils/network/nn_module.py | 10 ++-- .../config/ptz_simple_spread_masac_config.py | 14 +----- .../envs/petting_zoo_simple_spread_env.py | 3 +- .../test_petting_zoo_simple_spread_env.py | 8 +-- 10 files changed, 125 insertions(+), 30 deletions(-) create mode 100644 ding/example/mappo.py create mode 100644 ding/example/masac.py diff --git a/ding/example/mappo.py b/ding/example/mappo.py new file mode 100644 index 0000000000..53ca5dff3c --- /dev/null +++ b/ding/example/mappo.py @@ -0,0 +1,45 @@ +import gym +from ditk import logging +from ding.model import MAVAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, online_logger, termination_checker +from ding.utils import set_pkg_seed +from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = MAVAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(policy.learn_mode, log_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(online_logger(train_show_freq=10)) + task.use(termination_checker(max_env_step=int(1e6))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/example/masac.py b/ding/example/masac.py new file mode 100644 index 0000000000..a268c7366b --- /dev/null +++ b/ding/example/masac.py @@ -0,0 +1,49 @@ +import gym +from ditk import logging +from ding.model import MAQAC +from ding.policy import SACDiscretePolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, CkptSaver, \ + data_pusher, online_logger, termination_checker, 
eps_greedy_handler +from ding.utils import set_pkg_seed +from dizoo.petting_zoo.config.ptz_simple_spread_masac_config import main_config, create_config +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = BaseEnvManagerV2( + env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = MAQAC(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = SACDiscretePolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use( + StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size) + ) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(online_logger(train_show_freq=10)) + task.use(termination_checker(max_env_step=int(1e6))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 844e3e8cfb..d7b32e35aa 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -73,7 +73,7 @@ def _logger(ctx: "OnlineRLContext"): else: output = ctx.train_output for k, v in output.items(): - if k in ['priority']: + if k in ['priority', 'td_error_priority']: continue if "[scalars]" in k: new_k = k.split(']')[-1] diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index b71a3bd089..c068de5d23 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -7,13 +7,14 @@ from ding.framework import task, OfflineRLContext, OnlineRLContext -def trainer(cfg: EasyDict, policy: Policy) -> Callable: +def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable: """ Overview: The middleware that executes a single training process. Arguments: - cfg (:obj:`EasyDict`): Config. - policy (:obj:`Policy`): The policy to be trained in step-by-step mode. + - log_freq (:obj:`int`): The frequency (iteration) of showing log. """ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): @@ -31,7 +32,7 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return data = ctx.train_data train_output = policy.forward(ctx.train_data) - if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0: + if ctx.train_iter % log_freq == 0: if isinstance(ctx, OnlineRLContext): logging.info( 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( @@ -50,13 +51,13 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return _train -def multistep_trainer(policy: Policy, log_freq: int) -> Callable: +def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable: """ Overview: The middleware that executes training for a target num of steps. 
Arguments: - policy (:obj:`Policy`): The policy specialized for multi-step training. - - int (:obj:`int`): The frequency (iteration) of showing log. + - log_freq (:obj:`int`): The frequency (iteration) of showing log. """ last_log_iter = -1 diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 91184a7b9b..6fe63ccc79 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from ding.framework import Context, OnlineRLContext + from ding.policy import Policy + from ding.reward_model import BaseRewardModel class OffPolicyLearner: @@ -25,17 +27,19 @@ def __new__(cls, *args, **kwargs): def __init__( self, cfg: EasyDict, - policy, + policy: 'Policy', buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]], - reward_model=None + reward_model: Optional['BaseRewardModel'] = None, + log_freq: int = 100, ) -> None: """ Arguments: - cfg (:obj:`EasyDict`): Config. - policy (:obj:`Policy`): The policy to be trained. - - buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training. - - reward_model (:obj:`nn.Module`): Additional reward estimator likes RND, ICM, etc. \ + - buffer (:obj:`Buffer`): The replay buffer to store the data for training. + - reward_model (:obj:`BaseRewardModel`): Additional reward estimator likes RND, ICM, etc. \ default to None. + - log_freq (:obj:`int`): The frequency (iteration) of showing log. """ self.cfg = cfg self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_)) diff --git a/ding/policy/sac.py b/ding/policy/sac.py index 6f6f2a8491..ca0263305a 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -317,6 +317,9 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # target update self._target_model.update(self._learn_model.state_dict()) return { + 'total_loss': loss_dict['total_loss'].item(), + 'policy_loss': loss_dict['policy_loss'].item(), + 'critic_loss': loss_dict['critic_loss'].item(), 'cur_lr_q': self._optimizer_q.defaults['lr'], 'cur_lr_p': self._optimizer_policy.defaults['lr'], 'priority': td_error_per_sample.abs().tolist(), diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py index 732714e6b5..3b4a63fc9f 100644 --- a/ding/torch_utils/network/nn_module.py +++ b/ding/torch_utils/network/nn_module.py @@ -166,16 +166,18 @@ def conv2d_block( ) ) if norm_type is not None: - if norm_type is 'LN': + if norm_type == 'LN': # LN is implemented as GroupNorm with 1 group. block.append(nn.GroupNorm(1, out_channels)) - elif norm_type is 'GN': + elif norm_type == 'GN': block.append(nn.GroupNorm(num_groups_for_gn, out_channels)) elif norm_type in ['BN', 'IN', 'SyncBN']: block.append(build_normalization(norm_type, dim=2)(out_channels)) else: - raise KeyError("Invalid value in norm_type: {}. The valid norm_type are " - "BN, LN, IN, GN and SyncBN.".format(norm_type)) + raise KeyError( + "Invalid value in norm_type: {}. 
The valid norm_type are " + "BN, LN, IN, GN and SyncBN.".format(norm_type) + ) if activation is not None: block.append(activation) diff --git a/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py b/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py index ce000975c1..a3782138aa 100644 --- a/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py +++ b/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py @@ -24,25 +24,17 @@ cuda=True, on_policy=False, multi_agent=True, - # priority=True, - # priority_IS_weight=False, - random_collect_size=0, + random_collect_size=5000, model=dict( agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2, global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, action_shape=5, - # SAC concerned twin_critic=True, - actor_head_hidden_size=256, - critic_head_hidden_size=256, ), learn=dict( update_per_collect=50, batch_size=320, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== # learning_rates learning_rate_q=5e-4, learning_rate_policy=5e-4, @@ -51,16 +43,12 @@ discount_factor=0.99, alpha=0.2, auto_alpha=True, - log_space=True, - ignore_down=False, target_entropy=-2, ), collect=dict( n_sample=1600, - unroll_len=1, env_num=collector_env_num, ), - command=dict(), eval=dict( env_num=evaluator_env_num, evaluator=dict(eval_freq=50, ), diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 4a3fa0c41b..ce4b966fb7 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -180,6 +180,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: obs, rew, done, trunc, info = self._env.step(action) obs_n = self._process_obs(obs) rew_n = np.array([sum([rew[agent] for agent in self._agents])]) + rew_n = rew_n.astype(np.float32) # collide_sum = 0 # for i in range(self._num_agents): # collide_sum += info['n'][i][1] @@ -271,7 +272,7 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa 1 ) # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. 
- ret['action_mask'] = np.ones((self._num_agents, *self._action_dim)) + ret['action_mask'] = np.ones((self._num_agents, *self._action_dim)).astype(np.float32) return ret def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa diff --git a/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py index 9adc71e7f9..22117cf85f 100644 --- a/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py @@ -1,10 +1,10 @@ +from easydict import EasyDict +import pytest import numpy as np import pettingzoo from ding.utils import import_module -from easydict import EasyDict -import pytest -from dizoo.petting_zoo.envs.petting_zoo_env import PettingZooEnv +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv @pytest.mark.envtest @@ -39,6 +39,7 @@ def test_agent_obs_only(self): assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) assert isinstance(timestep.done, bool), timestep.done assert isinstance(timestep.reward, np.ndarray), timestep.reward + assert timestep.reward.dtype == np.float32 print(env.observation_space, env.action_space, env.reward_space) env.close() @@ -80,6 +81,7 @@ def test_dict_obs(self): assert timestep.obs['agent_alone_padding_state'].shape == ( n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 ) + assert timestep.obs['action_mask'].dtype == np.float32 assert isinstance(timestep.done, bool), timestep.done assert isinstance(timestep.reward, np.ndarray), timestep.reward print(env.observation_space, env.action_space, env.reward_space)
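
The two new examples can be run directly once DI-engine and PettingZoo's MPE simple_spread environment are installed, e.g. python ding/example/mappo.py or python ding/example/masac.py. Below is a minimal launcher sketch, assuming those dependencies and the files added above; it only forwards to the main() functions defined in ding/example/mappo.py and ding/example/masac.py, and its command-line handling is a hypothetical convenience for illustration rather than part of DI-engine.

# Minimal launcher sketch: forwards to the example entry points added by this patch.
# Assumes the patch is applied and that ding plus pettingzoo (MPE simple_spread) are importable.
import sys

from ding.example.mappo import main as run_mappo
from ding.example.masac import main as run_masac

if __name__ == "__main__":
    # Choose the algorithm from the first command-line argument; default to MAPPO.
    algo = sys.argv[1].lower() if len(sys.argv) > 1 else "mappo"
    if algo == "mappo":
        run_mappo()
    elif algo == "masac":
        run_masac()
    else:
        raise ValueError("Unknown algorithm '{}', expected 'mappo' or 'masac'.".format(algo))

As configured in the example scripts above, both pipelines log training progress via online_logger, save checkpoints under cfg.exp_name through CkptSaver, and terminate once termination_checker reaches max_env_step=1e6.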