diff --git a/README.md b/README.md index 5c24e25ea8..cd1816c80d 100644 --- a/README.md +++ b/README.md @@ -154,20 +154,21 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 3 | [box2d/lunarlander](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/box2d/lunarlander/lunarlander.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/lunarlander/envs) | | 4 | [classic_control/cartpole](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/cartpole/envs) | | 5 | [classic_control/pendulum](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/pendulum/pendulum.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/pendulum/envs) | -| 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.classic_control) | -| 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.gfootball/envs) | +| 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.classic_control) | +| 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.gfootball/envs) | | 8 | [minigrid](https://github.com/maximecb/gym-minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/minigrid/minigrid.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/minigrid/envs) | | 9 | [mujoco](https://github.com/openai/gym/tree/master/gym/envs/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/majoco/envs) | | 10 | [multiagent_particle](https://github.com/openai/multiagent-particle-envs) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/multiagent_particle/multiagent_particle.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_particle/envs) | | 11 | 
[overcooked](https://github.com/HumanCompatibleAI/overcooked-demo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/overcooked/overcooked.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/overcooded/envs) | | 12 | [procgen](https://github.com/openai/procgen) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/procgen/coinrun/coinrun.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/procgen) | | 13 | [pybullet](https://github.com/benelot/pybullet-gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/pybullet/pybullet.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pybullet/envs) | -| 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs) | +| 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs) | | 15 | [d4rl](https://github.com/rail-berkeley/d4rl) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | ![ori](dizoo/d4rl/d4rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/d4rl) | -| 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) | +| 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) | | 17 | pomdp atari | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pomdp/envs) | | 18 | [bsuite](https://github.com/deepmind/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/bsuite/bsuite.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bsuite/envs) | | 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification) | +| 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space @@ -181,6 +182,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` 
![IL](https://img.shields.io/badge/-IL/SL-purple) means Imitation Learning or Supervised Learning Dataset +![selfplay](https://img.shields.io/badge/-selfplay-blue) means environment that allows agent VS agent battle + P.S. some enviroments in Atari, such as **MontezumaRevenge**, are also sparse reward type ## Contribution diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index 62739af178..3a21f9b516 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -32,6 +32,15 @@ } +def is_abnormal_timestep(timestep: namedtuple) -> bool: + if isinstance(timestep.info, dict): + return timestep.info.get('abnormal', False) + elif isinstance(timestep.info, list) or isinstance(timestep.info, tuple): + return timestep.info[0].get('abnormal', False) or timestep.info[1].get('abnormal', False) + else: + raise TypeError("invalid env timestep type: {}".format(type(timestep.info))) + + class ShmBuffer(): """ Overview: @@ -452,7 +461,7 @@ def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]: timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get()) for env_id, timestep in timesteps.items(): - if timestep.info.get('abnormal', False): + if is_abnormal_timestep(timestep): self._env_states[env_id] = EnvState.ERROR continue if timestep.done: @@ -497,7 +506,7 @@ def worker_fn( elif cmd in method_name_list: if cmd == 'step': timestep = env.step(*args, **kwargs) - if timestep.info.get('abnormal', False): + if is_abnormal_timestep(timestep): ret = timestep else: if obs_buffer is not None: @@ -554,7 +563,7 @@ def worker_fn_robust( @timeout_wrapper(timeout=step_timeout) def step_fn(*args, **kwargs): timestep = env.step(*args, **kwargs) - if timestep.info.get('abnormal', False): + if is_abnormal_timestep(timestep): ret = timestep else: if obs_buffer is not None: @@ -768,7 +777,7 @@ def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]: for i, (env_id, timestep) in enumerate(timesteps.items()): timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get()) for env_id, timestep in timesteps.items(): - if timestep.info.get('abnormal', False): + if is_abnormal_timestep(timestep): self._env_states[env_id] = EnvState.ERROR continue if timestep.done: diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index a024e5092a..de142a2f43 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -121,8 +121,8 @@ def __init__( else: self.actor = [self.actor_encoder, self.actor_head] self.critic = [self.critic_encoder, self.critic_head] - # for convenience of call some apis(such as: self.critic.parameters()), but may cause - # misunderstanding when print(self) + # Convenient for calling some apis (e.g. 
self.critic.parameters()), + # but may cause misunderstanding when `print(self)` self.actor = nn.ModuleList(self.actor) self.critic = nn.ModuleList(self.critic) diff --git a/ding/worker/collector/__init__.py b/ding/worker/collector/__init__.py index 86dd466c9d..d72aac1ce9 100644 --- a/ding/worker/collector/__init__.py +++ b/ding/worker/collector/__init__.py @@ -4,6 +4,7 @@ from .sample_serial_collector import SampleSerialCollector from .episode_serial_collector import EpisodeSerialCollector from .battle_episode_serial_collector import BattleEpisodeSerialCollector +from .battle_sample_serial_collector import BattleSampleSerialCollector from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor from .interaction_serial_evaluator import InteractionSerialEvaluator diff --git a/ding/worker/collector/battle_sample_serial_collector.py b/ding/worker/collector/battle_sample_serial_collector.py new file mode 100644 index 0000000000..6bbb9c3788 --- /dev/null +++ b/ding/worker/collector/battle_sample_serial_collector.py @@ -0,0 +1,339 @@ +from typing import Optional, Any, List, Tuple +from collections import namedtuple, deque +from easydict import EasyDict +import numpy as np +import torch + +from ding.envs import BaseEnvManager +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists, one_time_warning +from ding.torch_utils import to_tensor, to_ndarray +from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions + + +@SERIAL_COLLECTOR_REGISTRY.register('sample_1v1') +class BattleSampleSerialCollector(ISerialCollector): + """ + Overview: + Sample collector(n_sample) with two policy battle + Interfaces: + __init__, reset, reset_env, reset_policy, collect, close + Property: + envstep + """ + + config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100) + + def __init__( + self, + cfg: EasyDict, + env: BaseEnvManager = None, + policy: List[namedtuple] = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector' + ) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`EasyDict`): Config dict + - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) + - policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy + - tb_logger (:obj:`SummaryWriter`): tensorboard handle + """ + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = cfg.collect_print_freq + self._deepcopy_obs = cfg.deepcopy_obs + self._transform_obs = cfg.transform_obs + self._cfg = cfg + self._timer = EasyTimer() + self._end_flag = False + + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + self._traj_len = float("inf") + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. 
+ Arguments: + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None: + """ + Overview: + Reset the policy. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + assert len(_policy) == 2, "1v1 sample collector needs 2 policy, but found {}".format(len(_policy)) + self._policy = _policy + self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None) + self._unroll_len = _policy[0].get_attribute('unroll_len') + self._on_policy = _policy[0].get_attribute('cfg').on_policy + if self._default_n_sample is not None: + self._traj_len = max( + self._unroll_len, + self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0) + ) + self._logger.debug( + 'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format( + self._default_n_sample, self._env_num, self._traj_len + ) + ) + else: + self._traj_len = INF + for p in self._policy: + p.reset() + + def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the environment and policy. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + + self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs) + self._policy_output_pool = CachePool('policy_output', self._env_num) + # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions + self._traj_buffer = { + env_id: {policy_id: TrajBuffer(maxlen=self._traj_len) + for policy_id in range(2)} + for env_id in range(self._env_num) + } + self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)} + + self._episode_info = [] + self._total_envstep_count = 0 + self._total_episode_count = 0 + self._total_train_sample_count = 0 + self._total_duration = 0 + self._last_train_iter = 0 + self._end_flag = False + + def _reset_stat(self, env_id: int) -> None: + """ + Overview: + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. 
+        Arguments:
+            - env_id (:obj:`int`): the id where we need to reset the collector's state
+        """
+        for i in range(2):
+            self._traj_buffer[env_id][i].clear()
+        self._obs_pool.reset(env_id)
+        self._policy_output_pool.reset(env_id)
+        self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
+
+    @property
+    def envstep(self) -> int:
+        """
+        Overview:
+            Return the total envstep count.
+        Return:
+            - envstep (:obj:`int`): the total envstep count
+        """
+        return self._total_envstep_count
+
+    def close(self) -> None:
+        """
+        Overview:
+            Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+                and close the tb_logger.
+        """
+        if self._end_flag:
+            return
+        self._end_flag = True
+        self._env.close()
+        self._tb_logger.flush()
+        self._tb_logger.close()
+
+    def __del__(self) -> None:
+        """
+        Overview:
+            Execute the close command and close the collector. __del__ is automatically called to \
+                destroy the collector instance when the collector finishes its work
+        """
+        self.close()
+
+    def collect(self,
+                n_sample: Optional[int] = None,
+                train_iter: int = 0,
+                policy_kwargs: Optional[dict] = None) -> Tuple[List[Any], List[Any]]:
+        """
+        Overview:
+            Collect `n_sample` data with policy_kwargs, from a policy that has been trained for `train_iter` iterations
+        Arguments:
+            - n_sample (:obj:`int`): the number of training samples to collect
+            - train_iter (:obj:`int`): the number of training iterations
+            - policy_kwargs (:obj:`dict`): the keyword args for policy forward
+        Returns:
+            - return_data (:obj:`List`): A list containing training samples, one sub-list per policy.
+            - return_info (:obj:`List`): A list containing episode info, one sub-list per policy.
+        """
+        if n_sample is None:
+            if self._default_n_sample is None:
+                raise RuntimeError("Please specify collect n_sample")
+            else:
+                n_sample = self._default_n_sample
+        if n_sample % self._env_num != 0:
+            one_time_warning(
+                "Please make sure n_sample is divisible by env_num: {}/{}, otherwise it may cause convergence \
+                problems in a few algorithms".format(n_sample, self._env_num)
+            )
+        if policy_kwargs is None:
+            policy_kwargs = {}
+        collected_sample = [0 for _ in range(2)]
+        return_data = [[] for _ in range(2)]
+        return_info = [[] for _ in range(2)]
+
+        while any([c < n_sample for c in collected_sample]):
+            with self._timer:
+                # Get current env obs.
+                obs = self._env.ready_obs
+                # Policy forward.
+                self._obs_pool.update(obs)
+                if self._transform_obs:
+                    obs = to_tensor(obs, dtype=torch.float32)
+                obs = dicts_to_lists(obs)
+                policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
+                self._policy_output_pool.update(policy_output)
+                # Interact with env.
+ actions = {} + for policy_output_item in policy_output: + for env_id, output in policy_output_item.items(): + if env_id not in actions: + actions[env_id] = [] + actions[env_id].append(output['action']) + actions = to_ndarray(actions) + timesteps = self._env.step(actions) + + # TODO(nyz) this duration may be inaccurate in async env + interaction_duration = self._timer.value / len(timesteps) + + # TODO(nyz) vectorize this for loop + for env_id, timestep in timesteps.items(): + self._env_info[env_id]['step'] += 1 + self._total_envstep_count += 1 + with self._timer: + for policy_id, policy in enumerate(self._policy): + policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep] + policy_timestep = type(timestep)(*policy_timestep_data) + transition = self._policy[policy_id].process_transition( + self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id], + policy_timestep + ) + transition['collect_iter'] = train_iter + self._traj_buffer[env_id][policy_id].append(transition) + # prepare data + if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len: + transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id]) + train_sample = self._policy[policy_id].get_train_sample(transitions) + return_data[policy_id].extend(train_sample) + self._total_train_sample_count += len(train_sample) + self._env_info[env_id]['train_sample'] += len(train_sample) + collected_sample[policy_id] += len(train_sample) + self._traj_buffer[env_id][policy_id].clear() + + self._env_info[env_id]['time'] += self._timer.value + interaction_duration + + # If env is done, record episode info and reset + if timestep.done: + self._total_episode_count += 1 + info = { + 'reward0': timestep.info[0]['final_eval_reward'], + 'reward1': timestep.info[1]['final_eval_reward'], + 'time': self._env_info[env_id]['time'], + 'step': self._env_info[env_id]['step'], + 'train_sample': self._env_info[env_id]['train_sample'], + } + self._episode_info.append(info) + for i, p in enumerate(self._policy): + p.reset([env_id]) + self._reset_stat(env_id) + for policy_id in range(2): + return_info[policy_id].append(timestep.info[policy_id]) + # log + self._output_log(train_iter) + return_data = [r[:n_sample] for r in return_data] + return return_data, return_info + + def _output_log(self, train_iter: int) -> None: + """ + Overview: + Print the output log information. You can refer to Docs/Best Practice/How to understand\ + training generated folders/Serial mode/log/collector for more details. + Arguments: + - train_iter (:obj:`int`): the number of training iteration. 
+ """ + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward0 = [d['reward0'] for d in self._episode_info] + episode_reward1 = [d['reward1'] for d in self._episode_info] + self._total_duration += duration + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_episode_per_sec': episode_count / duration, + 'collect_time': duration, + 'reward0_mean': np.mean(episode_reward0), + 'reward0_std': np.std(episode_reward0), + 'reward0_max': np.max(episode_reward0), + 'reward0_min': np.min(episode_reward0), + 'reward1_mean': np.mean(episode_reward1), + 'reward1_std': np.std(episode_reward1), + 'reward1_max': np.max(episode_reward1), + 'reward1_min': np.min(episode_reward1), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + } + self._episode_info.clear() + self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + for k, v in info.items(): + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if k in ['total_envstep_count']: + continue + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) diff --git a/dizoo/atari/config/serial/pong/pong_ppo_config.py b/dizoo/atari/config/serial/pong/pong_ppo_config.py index d9e5cf65e7..341777f390 100644 --- a/dizoo/atari/config/serial/pong/pong_ppo_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppo_config.py @@ -14,8 +14,8 @@ ), policy=dict( cuda=True, + # (bool) whether to use on-policy training pipeline(on-policy means behaviour policy and training policy are the same) on_policy=False, - # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same) model=dict( obs_shape=[4, 84, 84], action_shape=6, diff --git a/dizoo/league_demo/league_demo_ppo_main.py b/dizoo/league_demo/league_demo_ppo_main.py index 43bf2229aa..596b3565a5 100644 --- a/dizoo/league_demo/league_demo_ppo_main.py +++ b/dizoo/league_demo/league_demo_ppo_main.py @@ -120,7 +120,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)): main_player = league.get_player_by_id(main_key) main_learner = learners[main_key] main_collector = collectors[main_key] - # collect_mode ppo use multimonial sample for selecting action + # collect_mode ppo use multinomial sample for selecting action evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator) evaluator1_cfg.stop_value = cfg.env.stop_value[0] evaluator1 = BattleInteractionSerialEvaluator( diff --git a/dizoo/league_demo/selfplay_demo_ppo_main.py b/dizoo/league_demo/selfplay_demo_ppo_main.py index a78d5387c6..9eb8064dc9 100644 --- a/dizoo/league_demo/selfplay_demo_ppo_main.py +++ b/dizoo/league_demo/selfplay_demo_ppo_main.py @@ -87,7 +87,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)): tb_logger, exp_name=cfg.exp_name ) - # collect_mode ppo use multimonial sample for selecting action + # collect_mode ppo use multinomial sample for selecting action evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator) evaluator1_cfg.stop_value = cfg.env.stop_value[0] evaluator1 = 
BattleInteractionSerialEvaluator( diff --git a/dizoo/slime_volley/__init__.py b/dizoo/slime_volley/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/slime_volley/config/__init__.py b/dizoo/slime_volley/config/__init__.py new file mode 100644 index 0000000000..9e0be55660 --- /dev/null +++ b/dizoo/slime_volley/config/__init__.py @@ -0,0 +1 @@ +from .slime_volley_league_ppo_config import slime_volley_league_ppo_config diff --git a/dizoo/slime_volley/config/slime_volley_league_ppo_config.py b/dizoo/slime_volley/config/slime_volley_league_ppo_config.py new file mode 100644 index 0000000000..d48ca085d0 --- /dev/null +++ b/dizoo/slime_volley/config/slime_volley_league_ppo_config.py @@ -0,0 +1,78 @@ +from easydict import EasyDict + +slime_volley_league_ppo_config = dict( + exp_name="slime_volley_league_ppo", + env=dict( + collector_env_num=8, + evaluator_env_num=10, + n_evaluator_episode=100, + stop_value=0, + # Single-agent env for evaluator; Double-agent env for collector. + # Should be assigned True or False in code. + is_evaluator=None, + manager=dict(shared_memory=False, ), + env_id="SlimeVolley-v0", + ), + policy=dict( + cuda=False, + continuous=False, + model=dict( + obs_shape=12, + action_shape=6, + encoder_hidden_size_list=[32, 32], + critic_head_hidden_size=32, + actor_head_hidden_size=32, + share_encoder=False, + ), + learn=dict( + update_per_collect=3, + batch_size=32, + learning_rate=0.00001, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + ), + collect=dict( + n_episode=128, unroll_len=1, discount_factor=1.0, gae_lambda=1.0, collector=dict(get_train_sample=True, ) + ), + other=dict( + league=dict( + player_category=['default'], + path_policy="slime_volley_league_ppo/policy", + active_players=dict( + main_player=1, + main_exploiter=1, + league_exploiter=1, + ), + main_player=dict( + one_phase_step=200, + branch_probs=dict( + pfsp=0.5, + sp=1.0, + ), + strong_win_rate=0.7, + ), + main_exploiter=dict( + one_phase_step=200, + branch_probs=dict(main_players=1.0, ), + strong_win_rate=0.7, + min_valid_win_rate=0.3, + ), + league_exploiter=dict( + one_phase_step=200, + branch_probs=dict(pfsp=1.0, ), + strong_win_rate=0.7, + mutate_prob=0.0, + ), + use_pretrain=False, + use_pretrain_init_historical=False, + payoff=dict( + type='battle', + decay=0.99, + min_win_rate_games=8, + ) + ), + ), + ), +) +slime_volley_league_ppo_config = EasyDict(slime_volley_league_ppo_config) diff --git a/dizoo/slime_volley/config/slime_volley_ppo_config.py b/dizoo/slime_volley/config/slime_volley_ppo_config.py new file mode 100644 index 0000000000..43430e8ae0 --- /dev/null +++ b/dizoo/slime_volley/config/slime_volley_ppo_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict +from ding.entry import serial_pipeline_onpolicy + +slime_volley_ppo_config = dict( + exp_name='slime_volley_ppo', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + agent_vs_agent=False, + stop_value=1000000, + env_id="SlimeVolley-v0", + ), + policy=dict( + cuda=True, + on_policy=True, + continuous=False, + model=dict( + obs_shape=12, + action_shape=6, + encoder_hidden_size_list=[64, 64], + critic_head_hidden_size=64, + actor_head_hidden_size=64, + share_encoder=False, + ), + learn=dict( + epoch_per_collect=5, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + ), + collect=dict( + n_sample=4096, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +slime_volley_ppo_config = 
EasyDict(slime_volley_ppo_config) +main_config = slime_volley_ppo_config +slime_volley_ppo_create_config = dict( + env=dict( + type='slime_volley', + import_names=['dizoo.slime_volley.envs.slime_volley_env'], + ), + env_manager=dict(type='subprocess'), # save replay must use base + policy=dict(type='ppo'), +) +slime_volley_ppo_create_config = EasyDict(slime_volley_ppo_create_config) +create_config = slime_volley_ppo_create_config + +if __name__ == "__main__": + serial_pipeline_onpolicy([main_config, create_config], seed=0) diff --git a/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py b/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py new file mode 100644 index 0000000000..c2bff324ea --- /dev/null +++ b/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py @@ -0,0 +1,83 @@ +import os +import gym +import numpy as np +import copy +import torch +from tensorboardX import SummaryWriter +from functools import partial + +from ding.config import compile_config +from ding.worker import BaseLearner, BattleSampleSerialCollector, NaiveReplayBuffer, InteractionSerialEvaluator +from ding.envs import SyncSubprocessEnvManager +from ding.policy import PPOPolicy +from ding.model import VAC +from ding.utils import set_pkg_seed +from dizoo.slime_volley.envs import SlimeVolleyEnv +from dizoo.slime_volley.config.slime_volley_ppo_config import main_config + + +def main(cfg, seed=0, max_iterations=int(1e10)): + cfg = compile_config( + cfg, + SyncSubprocessEnvManager, + PPOPolicy, + BaseLearner, + BattleSampleSerialCollector, + InteractionSerialEvaluator, + NaiveReplayBuffer, + save_cfg=True + ) + collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num + collector_env_cfg = copy.deepcopy(cfg.env) + collector_env_cfg.agent_vs_agent = True + evaluator_env_cfg = copy.deepcopy(cfg.env) + evaluator_env_cfg.agent_vs_agent = False + collector_env = SyncSubprocessEnvManager( + env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SyncSubprocessEnvManager( + env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager + ) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner( + cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1' + ) + collector = BattleSampleSerialCollector( + cfg.policy.collect.collector, + collector_env, [policy.collect_mode, policy.collect_mode], + tb_logger, + exp_name=cfg.exp_name + ) + evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator) + evaluator_cfg.stop_value = cfg.env.stop_value + evaluator = InteractionSerialEvaluator( + evaluator_cfg, + evaluator_env, + policy.eval_mode, + tb_logger, + exp_name=cfg.exp_name, + instance_name='builtin_ai_evaluator' + ) + + learner.call_hook('before_run') + for _ in range(max_iterations): + if evaluator.should_eval(learner.train_iter): + stop_flag, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop_flag: + break + new_data, _ = collector.collect(train_iter=learner.train_iter) + train_data = new_data[0] + new_data[1] + learner.train(train_data, collector.envstep) + learner.call_hook('after_run') + + +if __name__ == "__main__": 
+    main(main_config)
diff --git a/dizoo/slime_volley/envs/__init__.py b/dizoo/slime_volley/envs/__init__.py
new file mode 100644
index 0000000000..7fc6e04830
--- /dev/null
+++ b/dizoo/slime_volley/envs/__init__.py
@@ -0,0 +1 @@
+from .slime_volley_env import SlimeVolleyEnv
diff --git a/dizoo/slime_volley/envs/slime_volley_env.py b/dizoo/slime_volley/envs/slime_volley_env.py
new file mode 100644
index 0000000000..8a9c68ef82
--- /dev/null
+++ b/dizoo/slime_volley/envs/slime_volley_env.py
@@ -0,0 +1,189 @@
+from namedlist import namedlist
+import numpy as np
+import gym
+from typing import Any, Union, List, Optional
+import copy
+import slimevolleygym
+
+from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_tensor, to_ndarray
+
+
+class GymSelfPlayMonitor(gym.wrappers.Monitor):
+
+    def step(self, *args, **kwargs):
+        self._before_step(*args, **kwargs)
+        observation, reward, done, info = self.env.step(*args, **kwargs)
+        done = self._after_step(observation, reward, done, info)
+
+        return observation, reward, done, info
+
+    def _before_step(self, *args, **kwargs):
+        if not self.enabled:
+            return
+        self.stats_recorder.before_step(args[0])
+
+
+@ENV_REGISTRY.register('slime_volley')
+class SlimeVolleyEnv(BaseEnv):
+
+    def __init__(self, cfg) -> None:
+        self._cfg = cfg
+        self._init_flag = False
+        self._replay_path = None
+        # agent_vs_bot env is single-agent env: obs, action, done, info are all single.
+        # agent_vs_agent env is double-agent env: obs, action, info are double, done is still single.
+        self._agent_vs_agent = cfg.agent_vs_agent
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        self._seed = seed
+        self._dynamic_seed = dynamic_seed
+        np.random.seed(self._seed)
+
+    def close(self) -> None:
+        if self._init_flag:
+            self._env.close()
+        self._init_flag = False
+
+    def step(self, action: Union[np.ndarray, List[np.ndarray]]):
+        if self._agent_vs_agent:
+            assert isinstance(action, list) and isinstance(action[0], np.ndarray)
+            action1, action2 = action[0], action[1]
+        else:
+            assert isinstance(action, np.ndarray)
+            action1, action2 = action, None
+        assert isinstance(action1, np.ndarray), type(action1)
+        assert action2 is None or isinstance(action2, np.ndarray), type(action2)
+        if action1.shape == (1, ):
+            action1 = action1.squeeze()  # 0-dim array
+        if action2 is not None and action2.shape == (1, ):
+            action2 = action2.squeeze()  # 0-dim array
+        action1 = SlimeVolleyEnv._process_action(action1)
+        action2 = SlimeVolleyEnv._process_action(action2)
+        obs1, rew, done, info = self._env.step(action1, action2)
+        obs1 = to_ndarray(obs1).astype(np.float32)
+        self._final_eval_reward += rew
+        # info keys: ('ale.lives', 'ale.otherLives', 'otherObs', 'state', 'otherState')
+        if self._agent_vs_agent:
+            info = [
+                {
+                    'ale.lives': info['ale.lives'],
+                    'state': info['state']
+                }, {
+                    'ale.lives': info['ale.otherLives'],
+                    'state': info['otherState'],
+                    'obs': info['otherObs']
+                }
+            ]
+            if done:
+                info[0]['final_eval_reward'] = self._final_eval_reward
+                info[1]['final_eval_reward'] = -self._final_eval_reward
+        else:
+            if done:
+                info['final_eval_reward'] = self._final_eval_reward
+        reward = to_ndarray([rew]).astype(np.float32)
+        if self._agent_vs_agent:
+            obs2 = info[1]['obs']
+            obs2 = to_ndarray(obs2).astype(np.float32)
+            observations = np.stack([obs1, obs2], axis=0)
+            rewards = to_ndarray([rew, -rew]).astype(np.float32)
+            rewards = rewards[..., np.newaxis]
return BaseEnvTimestep(observations, rewards, done, info) + else: + return BaseEnvTimestep(obs1, reward, done, info) + + def reset(self): + if not self._init_flag: + self._env = gym.make(self._cfg.env_id) + if self._replay_path is not None: + self._env = GymSelfPlayMonitor( + self._env, self._replay_path, video_callable=lambda episode_id: True, force=True + ) + self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + self._env.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + self._env.seed(self._seed) + self._final_eval_reward = 0 + obs = self._env.reset() + obs = to_ndarray(obs).astype(np.float32) + if self._agent_vs_agent: + obs = np.stack([obs, obs], axis=0) + return obs + else: + return obs + + def info(self): + T = EnvElementInfo + return BaseEnvInfo( + agent_num=2, + obs_space=T( + (2, 12) if self._agent_vs_agent else (12, ), + { + 'min': [float("-inf") for _ in range(12)], + 'max': [float("inf") for _ in range(12)], + 'dtype': np.float32, + }, + ), + # [min, max) + # 6 valid actions: + act_space=T( + (1, ), + { + 'min': 0, + 'max': 6, + 'dtype': int, + }, + ), + rew_space=T( + (1, ), + { + 'min': -5.0, + 'max': 5.0, + 'dtype': np.float32, + }, + ), + use_wrappers=None, + ) + + def __repr__(self): + return "DI-engine Slime Volley Env" + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + + @staticmethod + def _process_action(action: np.ndarray, _type: str = "binary") -> np.ndarray: + if action is None: + return None + action = action.item() + # Env receives action in [0, 5] (int type). Can translater into: + # 1) "binary" type: np.array([0, 1, 0]) + # 2) "atari" type: NOOP, LEFT, UPLEFT, UP, UPRIGHT, RIGHT + to_atari_action = { + 0: 0, # NOOP + 1: 4, # LEFT + 2: 7, # UPLEFT + 3: 2, # UP + 4: 6, # UPRIGHT + 5: 3, # RIGHT + } + to_binary_action = { + 0: [0, 0, 0], # NOOP + 1: [1, 0, 0], # LEFT (forward) + 2: [1, 0, 1], # UPLEFT (forward jump) + 3: [0, 0, 1], # UP (jump) + 4: [0, 1, 1], # UPRIGHT (backward jump) + 5: [0, 1, 0], # RIGHT (backward) + } + if _type == "binary": + return to_ndarray(to_binary_action[action]) + elif _type == "atari": + return to_atari_action[action] + else: + raise NotImplementedError diff --git a/dizoo/slime_volley/envs/test_slime_volley_env.py b/dizoo/slime_volley/envs/test_slime_volley_env.py new file mode 100644 index 0000000000..4774c0af11 --- /dev/null +++ b/dizoo/slime_volley/envs/test_slime_volley_env.py @@ -0,0 +1,36 @@ +import pytest +import numpy as np +from easydict import EasyDict + +from dizoo.slime_volley.envs.slime_volley_env import SlimeVolleyEnv + + +@pytest.mark.envtest +class TestSlimeVolley: + + @pytest.mark.parametrize('agent_vs_agent', [True, False]) + def test_slime_volley(self, agent_vs_agent): + total_rew = 0 + env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent})) + # env.enable_save_replay('replay_video') + obs1 = env.reset() + done = False + print(env._env.observation_space) + print('observation is like:', obs1) + done = False + while not done: + if agent_vs_agent: + action1 = np.random.randint(0, 2, (1, )) + action2 = np.random.randint(0, 2, (1, )) + action = [action1, action2] + else: + action = np.random.randint(0, 2, (1, )) + observations, rewards, done, infos = env.step(action) + total_rew += rewards[0] + obs1, obs2 = observations[0], observations[1] + assert obs1.shape 
== obs2.shape, (obs1.shape, obs2.shape) + if agent_vs_agent: + agent_lives, opponent_lives = infos[0]['ale.lives'], infos[1]['ale.lives'] + if agent_vs_agent: + assert agent_lives == 0 or opponent_lives == 0, (agent_lives, opponent_lives) + print("total reward is:", total_rew) diff --git a/dizoo/slime_volley/slime_volley.gif b/dizoo/slime_volley/slime_volley.gif new file mode 100644 index 0000000000..68326a0b8d Binary files /dev/null and b/dizoo/slime_volley/slime_volley.gif differ diff --git a/setup.py b/setup.py index 6fe5eb60f0..bc66c6b87d 100755 --- a/setup.py +++ b/setup.py @@ -136,6 +136,10 @@ 'whichcraft', 'joblib', ], + + 'slimevolleygym_env': [ + 'slimevolleygym', + ], 'k8s': [ 'kubernetes', ]
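
Note (reviewer sketch, not part of the patch): the new `is_abnormal_timestep` helper in `subprocess_env_manager.py` accepts both the single-agent case (one info dict) and the new 1v1 case (one info dict per player). A minimal runnable illustration, using a hypothetical `BaseEnvTimestep` stand-in that assumes the usual `(obs, reward, done, info)` field layout:

from collections import namedtuple

# Illustrative stand-in for ding.envs.BaseEnvTimestep (obs, reward, done, info).
BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])


def is_abnormal_timestep(timestep) -> bool:
    # Same logic as the helper added in this patch: single-agent envs report one
    # info dict, battle (1v1) envs report one info dict per player.
    if isinstance(timestep.info, dict):
        return timestep.info.get('abnormal', False)
    elif isinstance(timestep.info, (list, tuple)):
        return timestep.info[0].get('abnormal', False) or timestep.info[1].get('abnormal', False)
    else:
        raise TypeError("invalid env timestep type: {}".format(type(timestep.info)))


single = BaseEnvTimestep(obs=None, reward=0.0, done=False, info={'abnormal': True})
battle = BaseEnvTimestep(obs=None, reward=0.0, done=False, info=[{}, {'abnormal': True}])
assert is_abnormal_timestep(single) and is_abnormal_timestep(battle)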
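
Note (reviewer sketch, not part of the patch): the 1v1 data layout consumed by `BattleSampleSerialCollector` matches what `SlimeVolleyEnv.step` returns in agent_vs_agent mode: stacked observations of shape (2, 12), zero-sum rewards of shape (2, 1), and one info dict per player, while `done` stays a single bool. The per-policy split below mirrors the `policy_timestep_data` list comprehension inside `collect`; the concrete numbers are illustrative only:

import numpy as np
from collections import namedtuple

BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])

obs = np.zeros((2, 12), dtype=np.float32)               # one 12-dim observation per player
reward = np.array([[1.0], [-1.0]], dtype=np.float32)    # zero-sum: rew and -rew
info = [{'ale.lives': 5}, {'ale.lives': 4}]             # one info dict per player
timestep = BaseEnvTimestep(obs, reward, False, info)

# Each policy gets its own per-player view of the shared timestep; done is a
# single bool shared by both players, so it is not indexed.
for policy_id in range(2):
    per_policy = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
    policy_timestep = type(timestep)(*per_policy)
    assert policy_timestep.obs.shape == (12, )
    assert policy_timestep.reward.shape == (1, )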