From 8f7135b1a7135f12d347ab188a575aa61c037cae Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sat, 14 May 2022 22:57:35 +0800 Subject: [PATCH 01/70] demo(nyz): add naive dp demo --- dizoo/atari/example/atari_dqn_dp.py | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 dizoo/atari/example/atari_dqn_dp.py diff --git a/dizoo/atari/example/atari_dqn_dp.py b/dizoo/atari/example/atari_dqn_dp.py new file mode 100644 index 0000000000..b796c70ad6 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dp.py @@ -0,0 +1,53 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.torch_utils import DataParallel +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_dp' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + model = DataParallel(model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_train_iter=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() From 6733b690468b086f3eacde15012d9538d0d1ecc9 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 15 May 2022 19:47:37 +0800 Subject: [PATCH 02/70] demo(nyz): add naive ddp demo --- ding/utils/pytorch_ddp_dist_helper.py | 4 +- dizoo/atari/example/atari_dqn_ddp.py | 83 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 dizoo/atari/example/atari_dqn_ddp.py diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 96847be357..3c9d5881fb 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -114,7 +114,9 @@ def dist_finalize() -> None: Overview: Finalize distributed training resources """ - dist.destroy_process_group() + # This operation usually hangs out so we ignore it temporally. 
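For context on the comment above: `dist.destroy_process_group()` can hang when the ranks reach shutdown at different times, which is why this patch simply skips the call. One defensive pattern seen in other codebases, shown here only as a hedged sketch and not as what DI-engine does, is to guard and synchronize before tearing the group down:

import torch.distributed as dist


def safe_dist_finalize() -> None:
    # Only meaningful if init_process_group() was called earlier in this process.
    if dist.is_available() and dist.is_initialized():
        # Wait until every rank reaches this point, so no rank destroys the
        # group while a peer is still inside a collective call.
        dist.barrier()
        dist.destroy_process_group()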
+ # dist.destroy_process_group() + pass class DistContext: diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py new file mode 100644 index 0000000000..d4445bcb66 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -0,0 +1,83 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.utils import DistContext, get_rank +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): + import numpy as np + import torch + from ding.utils import broadcast + if rank == 0: + if max_env_step is None: + max_env_step = np.inf + if max_train_iter is None: + max_train_iter = np.inf + + def _check(ctx): + if rank == 0: + if ctx.env_step > max_env_step: + finish = torch.ones(1).long().cuda() + elif ctx.train_iter > max_train_iter: + finish = torch.ones(1).long().cuda() + else: + finish = torch.zeros(1).long().cuda() + else: + finish = torch.zeros(1).long().cuda() + broadcast(finish, 0) + task.finish = finish.cpu().bool().item() + + return _check + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ddp' + main_config.policy.learn.multi_gpu = True + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with DistContext(): + rank = get_rank() + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(ddp_termination_checker(max_train_iter=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() From b907c63aed5c4c75e260ae07170626125737fb3f Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 15 May 2022 21:58:49 +0800 Subject: [PATCH 03/70] feature(nyz): add naive tb_logger in new evaluator --- ding/framework/middleware/functional/evaluator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py 
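One detail worth noting in `ddp_termination_checker` above: the finish flag is packed into a CUDA `LongTensor` before being broadcast, because the NCCL backend typically used for multi-GPU training only operates on GPU tensors. A stripped-down sketch of the same rank-0 broadcast pattern with plain `torch.distributed` (it assumes an already initialized process group and one visible GPU per rank; the function name is invented for illustration):

import torch
import torch.distributed as dist


def broadcast_stop_flag(local_decision: bool, rank: int) -> bool:
    # Rank 0 decides whether to stop; the other ranks contribute a placeholder.
    flag = torch.tensor([int(local_decision) if rank == 0 else 0], dtype=torch.long).cuda()
    dist.broadcast(flag, src=0)  # after this call every rank holds rank 0's value
    return bool(flag.item())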
index 4d864b3bae..e19561211d 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,6 +156,8 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> """ env.seed(cfg.seed, dynamic_seed=False) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(cfg.exp_name) def _evaluate(ctx: "OnlineRLContext"): """ @@ -194,6 +196,8 @@ def _evaluate(ctx: "OnlineRLContext"): eval_monitor.update_reward(env_id, reward) episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) + tb_logger.add_scalar('basic/eval_episode_reward_mean-env_step', eval_reward, ctx.env_step) + tb_logger.add_scalar('basic/eval_episode_reward_mean-train_iter', eval_reward, ctx.train_iter) stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0 logging.info( 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( From 4b833d0629f2bfd5c2ddd19d9c1333526c092575 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Mon, 16 May 2022 12:17:18 +0800 Subject: [PATCH 04/70] Add singleton log writer --- ding/utils/__init__.py | 2 +- ding/utils/log_writer_helper.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 3c92d9667c..addc390dff 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -13,7 +13,7 @@ K8sLauncher from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock from .log_helper import build_logger, pretty_print, LoggerFactory -from .log_writer_helper import DistributedWriter +from .log_writer_helper import DistributedWriter, distributed_writer from .orchestrator_launcher import OrchestratorLauncher from .profiler_helper import Profiler, register_profiler from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \ diff --git a/ding/utils/log_writer_helper.py b/ding/utils/log_writer_helper.py index ea11adc41e..1506177317 100644 --- a/ding/utils/log_writer_helper.py +++ b/ding/utils/log_writer_helper.py @@ -103,3 +103,8 @@ def _parallel_fn(self: DistributedWriter, *args, **kwargs): ] for fn_name in ready_to_parallel_fns: setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name))) + +# Examples: +# In main, `distributed_writer.plugin(task.router, is_writer=True)`, +# In middleware, `distributed_writer.record()` +distributed_writer = DistributedWriter() From b0e62389e399eb0b2df94ccb5fa1eac93f45fa15 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Mon, 16 May 2022 18:25:15 +0800 Subject: [PATCH 05/70] Use get_instance on writer --- ding/utils/__init__.py | 2 +- ding/utils/log_writer_helper.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index addc390dff..3c92d9667c 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -13,7 +13,7 @@ K8sLauncher from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock from .log_helper import build_logger, pretty_print, LoggerFactory -from .log_writer_helper import DistributedWriter, distributed_writer +from .log_writer_helper import DistributedWriter from .orchestrator_launcher import OrchestratorLauncher from .profiler_helper import Profiler, register_profiler from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \ diff --git 
a/ding/utils/log_writer_helper.py b/ding/utils/log_writer_helper.py index 1506177317..f8b3b28181 100644 --- a/ding/utils/log_writer_helper.py +++ b/ding/utils/log_writer_helper.py @@ -17,6 +17,7 @@ class DistributedWriter(SummaryWriter): The best way is to use it in conjunction with the ``router`` to take advantage of the message \ and event components of the router (see ``writer.plugin``). """ + root = None def __init__(self, *args, **kwargs): self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True @@ -30,6 +31,21 @@ def __init__(self, *args, **kwargs): self._is_writer = False self._lazy_initialized = False + @classmethod + def get_instance(cls, *args, **kwargs) -> "DistributedWriter": + """ + Overview: + Get instance and set the root level instance on the first called. If args and kwargs is none, + this method will return root instance. + """ + if args or kwargs: + ins = cls(*args, **kwargs) + if cls.root is None: + cls.root = ins + return ins + else: + return cls.root + def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter": """ Overview: @@ -103,8 +119,3 @@ def _parallel_fn(self: DistributedWriter, *args, **kwargs): ] for fn_name in ready_to_parallel_fns: setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name))) - -# Examples: -# In main, `distributed_writer.plugin(task.router, is_writer=True)`, -# In middleware, `distributed_writer.record()` -distributed_writer = DistributedWriter() From 7a344e6a16d9ef1ef1ec47bbec0b1f33c1d1d0a4 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 16 May 2022 20:31:16 +0800 Subject: [PATCH 06/70] feature(nyz): add general logger middleware --- ding/example/sac.py | 6 ++- ding/framework/__init__.py | 6 +++ ding/framework/context.py | 6 ++- ding/framework/middleware/ckpt_handler.py | 2 +- .../middleware/functional/__init__.py | 1 + .../middleware/functional/evaluator.py | 4 -- .../framework/middleware/functional/logger.py | 48 +++++++++++++++++++ ding/policy/sac.py | 5 +- 8 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 ding/framework/middleware/functional/logger.py diff --git a/ding/example/sac.py b/ding/example/sac.py index b47352c609..863479f964 100644 --- a/ding/example/sac.py +++ b/ding/example/sac.py @@ -5,10 +5,10 @@ from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import DequeBuffer from ding.config import compile_config -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \ - CkptSaver, OffPolicyLearner, termination_checker + CkptSaver, OffPolicyLearner, termination_checker, online_logger from ding.utils import set_pkg_seed from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config @@ -17,6 +17,7 @@ def main(): logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2(env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(10)], cfg=cfg.env.manager) evaluator_env = BaseEnvManagerV2(env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(5)], cfg=cfg.env.manager) @@ -35,6 +36,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) 
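The `get_instance` classmethod added in PATCH 05 above is a small class-level singleton: the first call that passes arguments builds a writer and records it as the shared root, and later argument-less calls (for example from logging middlewares) get that same root back. A stripped-down, framework-free sketch of the pattern (the `ExpWriter` name is invented for illustration and is not DI-engine code):

from typing import Optional


class ExpWriter:
    """Minimal stand-in for DistributedWriter's root-instance bookkeeping."""
    root: Optional["ExpWriter"] = None

    def __init__(self, logdir: str):
        self.logdir = logdir

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "ExpWriter":
        if args or kwargs:
            ins = cls(*args, **kwargs)
            if cls.root is None:
                cls.root = ins  # first concrete construction becomes the process-wide root
            return ins
        return cls.root  # argument-less calls return whatever root was registered


writer = ExpWriter.get_instance("./exp_name")  # e.g. done once at startup, as ding_init does in the next patch
assert ExpWriter.get_instance() is writer      # e.g. inside a logger middleware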
task.use(CkptSaver(cfg, policy, train_freq=100)) task.use(termination_checker(max_train_iter=10000)) + task.use(online_logger()) task.run() diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index a1059a7ec4..a5b579b539 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -2,3 +2,9 @@ from .task import Task, task from .parallel import Parallel from .event_loop import EventLoop +from easydict import EasyDict +from ding.utils import DistributedWriter + + +def ding_init(cfg: EasyDict): + DistributedWriter.get_instance(cfg.exp_name) diff --git a/ding/framework/context.py b/ding/framework/context.py index 949e0e7d01..70c1143d0d 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -48,13 +48,14 @@ def __init__(self, *args, **kwargs) -> None: self.env_episode = 0 self.train_iter = 0 self.train_data = None + self.train_output = None # collect self.collect_kwargs = {} self.trajectories = None self.episodes = None self.trajectory_end_idx = [] # eval - self.eval_value = -np.inf + self.eval_value = None self.last_eval_iter = -1 self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter') @@ -70,8 +71,9 @@ def __init__(self, *args, **kwargs) -> None: self.train_epoch = 0 self.train_iter = 0 self.train_data = None + self.train_output = None # eval - self.eval_value = -np.inf + self.eval_value = None self.last_eval_iter = -1 self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 89772e66ee..4657eff8d9 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -54,7 +54,7 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.last_save_iter = ctx.train_iter # best eval reward so far - if ctx.eval_value > self.max_eval_value: + if ctx.eval_value and ctx.eval_value > self.max_eval_value: save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) self.max_eval_value = ctx.eval_value diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index defdc3f4c4..463ee84284 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -5,6 +5,7 @@ from .evaluator import interaction_evaluator from .termination_checker import termination_checker from .pace_controller import pace_controller +from .logger import online_logger # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index e19561211d..4d864b3bae 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,8 +156,6 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> """ env.seed(cfg.seed, dynamic_seed=False) - from tensorboardX import SummaryWriter - tb_logger = SummaryWriter(cfg.exp_name) def _evaluate(ctx: "OnlineRLContext"): """ @@ -196,8 +194,6 @@ def _evaluate(ctx: "OnlineRLContext"): eval_monitor.update_reward(env_id, reward) episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) - tb_logger.add_scalar('basic/eval_episode_reward_mean-env_step', eval_reward, ctx.env_step) - tb_logger.add_scalar('basic/eval_episode_reward_mean-train_iter', eval_reward, ctx.train_iter) stop_flag = eval_reward >= cfg.env.stop_value and 
ctx.train_iter > 0 logging.info( 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py new file mode 100644 index 0000000000..bec65cbd1f --- /dev/null +++ b/ding/framework/middleware/functional/logger.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, Callable, Dict, List +from collections import deque +from ding.utils import DistributedWriter + +if TYPE_CHECKING: + from ding.framework import OnlineRLContext + + +def online_logger(record_train_iter: bool = False) -> Callable: + writer = DistributedWriter.get_instance() + + def _logger(ctx: "OnlineRLContext"): + if ctx.eval_value is not None: + if record_train_iter: + writer.add_scalar('basic/eval_episode_reward_mean-env_step', ctx.eval_value, ctx.env_step) + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + else: + writer.add_scalar('basic/eval_episode_reward_mean', ctx.eval_value, ctx.env_step) + if ctx.train_output is not None: + if isinstance(ctx.train_output, deque): + output = ctx.train_output.pop() # only use latest output + else: + output = ctx.train_output + # TODO(nyz) ppo train log case + if isinstance(output, List): + raise NotImplementedError + for k, v in output.items(): + if k in ['priority']: + continue + if "[scalars]" in k: + new_k = k.split(']')[-1] + raise NotImplementedError + elif "[histogram]" in k: + new_k = k.split(']')[-1] + writer.add_histogram(new_k, v, ctx.env_step) + if record_train_iter: + writer.add_histogram(new_k, v, ctx.train_iter) + else: + if record_train_iter: + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step) + else: + writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step) + + return _logger + + +# TODO offline logger diff --git a/ding/policy/sac.py b/ding/policy/sac.py index 11306f30f7..2dc0ef1791 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -357,8 +357,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: torch.zeros_like(self._alpha)).requires_grad_() loss_dict['total_loss'] = sum(loss_dict.values()) - info_dict = {} - # ============= # after update # ============= @@ -375,8 +373,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'q_value_2': target_q_value[1].detach().mean().item(), 'target_value': target_value.detach().mean().item(), 'entropy': entropy.item(), - **info_dict, - **loss_dict + #**loss_dict } def _state_dict_learn(self) -> Dict[str, Any]: From ae37da243c1c30c0aed9fab3694a97655e0379fd Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 17 May 2022 10:55:19 +0800 Subject: [PATCH 07/70] feature(nyz): add soft update in DQN target network --- ding/policy/dqn.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 58122bcaf1..1ed1f68d05 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -143,12 +143,22 @@ def _init_learn(self) -> None: # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.learn.target_update_freq} - ) + if 'target_update_freq' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': 
self._cfg.learn.target_update_freq} + ) + elif 'target_theta' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.learn.target_theta} + ) + else: + raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta") self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') self._learn_model.reset() self._target_model.reset() @@ -189,7 +199,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: # Target q value with torch.no_grad(): target_q_value = self._target_model.forward(data['next_obs'])['logit'] - # Max q value action (main model) + # Max q value action (main model), i.e. Double DQN target_q_action = self._learn_model.forward(data['next_obs'])['action'] data_n = q_nstep_td_data( From dbee60ac47ec27cad99eb37ad252a8db39bf0c8b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 17 May 2022 11:03:00 +0800 Subject: [PATCH 08/70] fix(nyz): fix termination env_step bug and eval task.finish broadcast bug --- dizoo/atari/example/atari_dqn.py | 2 +- dizoo/atari/example/atari_dqn_ddp.py | 4 ++-- dizoo/atari/example/atari_dqn_dp.py | 2 +- dizoo/mujoco/example/mujoco_sac.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dizoo/atari/example/atari_dqn.py b/dizoo/atari/example/atari_dqn.py index 4f4be49ecb..1ac9fdc0c8 100644 --- a/dizoo/atari/example/atari_dqn.py +++ b/dizoo/atari/example/atari_dqn.py @@ -42,7 +42,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(termination_checker(max_train_iter=int(1e7))) + task.use(termination_checker(max_env_step=int(1e7))) task.run() diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py index d4445bcb66..bc07da3bfb 100644 --- a/dizoo/atari/example/atari_dqn_ddp.py +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -32,7 +32,7 @@ def _check(ctx): elif ctx.train_iter > max_train_iter: finish = torch.ones(1).long().cuda() else: - finish = torch.zeros(1).long().cuda() + finish = torch.LongTensor([task.finish]).cuda() else: finish = torch.zeros(1).long().cuda() broadcast(finish, 0) @@ -75,7 +75,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) if rank == 0: task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(ddp_termination_checker(max_train_iter=int(1e7), rank=rank)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) task.run() diff --git a/dizoo/atari/example/atari_dqn_dp.py b/dizoo/atari/example/atari_dqn_dp.py index b796c70ad6..cea9618061 100644 --- a/dizoo/atari/example/atari_dqn_dp.py +++ b/dizoo/atari/example/atari_dqn_dp.py @@ -45,7 +45,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(termination_checker(max_train_iter=int(1e7))) + task.use(termination_checker(max_env_step=int(1e7))) task.run() diff --git a/dizoo/mujoco/example/mujoco_sac.py b/dizoo/mujoco/example/mujoco_sac.py index 0349d8d161..471e4c8f29 100644 --- a/dizoo/mujoco/example/mujoco_sac.py +++ b/dizoo/mujoco/example/mujoco_sac.py @@ -37,7 +37,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=500)) - task.use(termination_checker(max_train_iter=int(3e6))) + 
task.use(termination_checker(max_env_step=int(3e6))) task.run() From e7c9d965123b0c1cd4977b06e84810c92e615845 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 15:52:02 +0800 Subject: [PATCH 09/70] Support distributed dqn --- ding/example/dqn_dist.py | 97 +++++++++++++++++++ ding/framework/context.py | 5 +- ding/framework/middleware/ckpt_handler.py | 6 +- .../middleware/functional/__init__.py | 1 + .../middleware/functional/evaluator.py | 5 +- .../middleware/functional/exchanger.py | 73 ++++++++++++++ ding/framework/parallel.py | 5 +- ding/utils/data/structure/__init__.py | 1 + ding/utils/data/structure/lifo_deque.py | 12 +++ 9 files changed, 198 insertions(+), 7 deletions(-) create mode 100644 ding/example/dqn_dist.py create mode 100644 ding/framework/middleware/functional/exchanger.py create mode 100644 ding/utils/data/structure/lifo_deque.py diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py new file mode 100644 index 0000000000..72516a62c6 --- /dev/null +++ b/ding/example/dqn_dist.py @@ -0,0 +1,97 @@ +""" +The distributed version of DQN pipeline. +With N workers = 1 learner + 1 evaluator + (N-2) actors + +# First Example —— Execute on one machine with multi processes. +Execute 4 processes with 1 learner + 1 evaluator + 2 actors +Remember to keep them connected by mesh to ensure that they can exchange information with each other. + +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 4 --topology mesh + +# Second Example —— Execute on multiple machines. +1. Execute 1 learner + 1 evaluator on one machine. +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 + +2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). + Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. + Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. + And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ + --ports 50517 --attach-to tcp://127.0.0.1:50515,tcp://127.0.0.1:50516 +""" +import gym +import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger +from ding.utils import set_pkg_seed +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + # cfg.env.stop_value = 99999999 # Don't stop + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
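The `target_theta` branch added to DQN in PATCH 07 above selects a momentum (soft) target update instead of the periodic hard copy. That kind of update is an exponential moving average of the online parameters; a hedged sketch of the idea is below, using the common convention new_target = (1 - theta) * target + theta * online (the exact convention of DI-engine's `momentum` wrapper should be checked against its implementation):

import copy
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, theta: float = 0.005) -> None:
    # Blend a small fraction of the online parameters into the target network.
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(1.0 - theta).add_(o_param, alpha=theta)


online_net = nn.Linear(4, 2)
target_net = copy.deepcopy(online_net)
for _ in range(10):  # called once per training iteration in a real loop
    soft_update(target_net, online_net, theta=0.005)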
+ + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if task.router.node_id == 0: # Learner + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=100)) + + elif task.router.node_id == 1: # Evaluator + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + + else: # Collectors + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/framework/context.py b/ding/framework/context.py index 70c1143d0d..3ad27a3b83 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -7,6 +7,7 @@ class Context(dict): Context is an object that pass contextual data between middlewares, whose life cycle is only one training iteration. It is a dict that reflect itself, so you can set any properties as you wish. + Note that the initial value of the property must be equal to False. """ def __init__(self, *args, **kwargs) -> None: @@ -56,7 +57,7 @@ def __init__(self, *args, **kwargs) -> None: self.trajectory_end_idx = [] # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter') @@ -74,6 +75,6 @@ def __init__(self, *args, **kwargs) -> None: self.train_output = None # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 4657eff8d9..382886e26d 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -17,7 +17,7 @@ class CkptSaver: The class used to save checkpoint data. """ - def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None): + def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None, save_finish: bool = True): """ Overview: Initialize the `CkptSaver`. 
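Both `context_exchanger` and `model_exchanger` introduced later in this patch are generator middlewares: everything before `yield` runs on the way into the pipeline (receiving buffered payloads), and everything after `yield` runs in the backward stage once the downstream middlewares have finished (emitting payloads to remote processes). A minimal, framework-free sketch of that forward/backward convention (the tiny `run_pipeline` driver is illustrative only and is not DI-engine's `task` implementation):

import inspect
from typing import Callable, List


def run_pipeline(middlewares: List[Callable], ctx: dict) -> None:
    backward_stack = []
    for mw in middlewares:
        result = mw(ctx)
        if inspect.isgenerator(result):
            next(result)                  # run the part before `yield`
            backward_stack.append(result)
    for gen in reversed(backward_stack):  # unwind after every middleware ran forward
        try:
            next(gen)                     # run the part after `yield`
        except StopIteration:
            pass


def exchanger(ctx: dict):
    ctx["model"] = "received state_dict"    # forward stage: consume incoming data
    yield
    print("backward stage, sending:", ctx)  # backward stage: publish results


def learner(ctx: dict):
    ctx["train_iter"] = ctx.get("train_iter", 0) + 1


run_pipeline([exchanger, learner], ctx={})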
@@ -25,6 +25,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.exp_name`. - policy (:obj:`Policy`): Policy used to save the checkpoint. - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data. + - save_finish (:obj:`int`): Whether save final ckpt when ``task.finish = True``. """ self.policy = policy self.train_freq = train_freq @@ -33,6 +34,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No os.mkdir(self.prefix) self.last_save_iter = 0 self.max_eval_value = -np.inf + self.save_finish = save_finish def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: """ @@ -59,5 +61,5 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.max_eval_value = ctx.eval_value # finish - if task.finish: + if task.finish and self.save_finish: save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 463ee84284..743c61f8e3 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -6,6 +6,7 @@ from .termination_checker import termination_checker from .pace_controller import pace_controller from .logger import online_logger +from .exchanger import context_exchanger, model_exchanger # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 4d864b3bae..b1c9599953 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -169,7 +169,8 @@ def _evaluate(ctx: "OnlineRLContext"): - eval_value (:obj:`float`): The average reward in the current evaluation. """ - if ctx.last_eval_iter != -1 and \ + # evaluation will be executed if the task begins or enough train_iter after last evaluation + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return @@ -214,7 +215,7 @@ def metric_evaluator(cfg: EasyDict, policy: Policy, dataset: Dataset, metric: IM def _evaluate(ctx: "Context"): # evaluation will be executed if the task begins or enough train_iter after last evaluation - if ctx.last_eval_iter != -1 and \ + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return diff --git a/ding/framework/middleware/functional/exchanger.py b/ding/framework/middleware/functional/exchanger.py new file mode 100644 index 0000000000..8710855f40 --- /dev/null +++ b/ding/framework/middleware/functional/exchanger.py @@ -0,0 +1,73 @@ +from time import sleep +from typing import TYPE_CHECKING, List, Dict +from ding.framework import task +from ding.utils.data.structure.lifo_deque import LifoDeque +if TYPE_CHECKING: + from ding.framework.context import Context + from torch.nn import Module + + +def context_exchanger(send_keys: List[str] = None, recv_keys: List[str] = None, skip_n_iter: int = 0): + """ + Overview: + Send data from context in the backward stage. + Buffer received data and wait if not get any data. + Arguments: + - send_keys (:obj:`List[str]`): Keys need to be sent. + - recv_keys (:obj:`List[str]`): Keys need to be received. 
+ - skip_n_iter (:obj:`int`): Whether to skip the first N round of waiting, + e.g. collecting data without waiting for a new model in the first N round, + while training a model that needs to wait for data in the first round. + """ + event_name = "context_exchanger" + + bufferd_payloads = LifoDeque(maxsize=100) + task.on(event_name, lambda payload: bufferd_payloads.put(payload)) + + def _context_exchanger(ctx: "Context"): + if recv_keys: + if ctx.total_step >= skip_n_iter: + payload: Dict = bufferd_payloads.get() + for key in recv_keys: + value = payload.get(key) + if value: + ctx[key] = value + + if send_keys: + yield + payload = {} + for key in send_keys: + payload[key] = ctx.get(key) + if payload: + task.emit(event_name, payload, only_remote=True) + + return _context_exchanger + + +def model_exchanger(model: "Module", is_learner: bool = False): + """ + Overview: + Exchange model between processes, only the learner will send the model, + otherwise the model will only be received. + If you are using a shared model on a single host, there is no need to use this middleware. + Arguments: + - model (:obj:`torch.nn.Module`): Pytorch module. + - is_learner (:obj:`bool`): Whether use this middleware as learner or not. + """ + event_name = "model_exchanger" + bufferd_state_dict = LifoDeque(maxsize=1) + + if not is_learner: + task.on(event_name, lambda state_dict: bufferd_state_dict.put(state_dict)) + + def _model_exchanger(ctx: "Context"): + if not is_learner: + if ctx.total_step != 0: # Skip first iteration + state_dict = bufferd_state_dict.get() + model.load_state_dict(state_dict) + + if is_learner: + yield + task.emit(event_name, model.state_dict(), only_remote=True) + + return _model_exchanger diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 469ae7e77f..6db8808600 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -34,6 +34,7 @@ def __init__(self) -> None: def _run( self, node_id: int, + n_parallel_workers: int, labels: Optional[Set[str]] = None, auto_recover: bool = False, max_retries: int = float("inf"), @@ -41,6 +42,7 @@ def _run( **kwargs ) -> None: self.node_id = node_id + self.n_parallel_workers = n_parallel_workers self.labels = labels or set() self.auto_recover = auto_recover self.max_retries = max_retries @@ -156,6 +158,7 @@ def topology_network(i: int) -> List[str]: "node_id": candidate_node_ids[i], "listen_to": nodes[i], "attach_to": topology_network(i), + "n_parallel_workers": n_parallel_workers, } runner_params.append(runner_kwargs) @@ -166,7 +169,7 @@ def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[Li runner_params = [] candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0) for i in range(n_parallel_workers): - runner_kwargs = {**kwargs, "node_id": candidate_node_ids[i]} + runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]} runner_params.append(runner_kwargs) return runner_params diff --git a/ding/utils/data/structure/__init__.py b/ding/utils/data/structure/__init__.py index 9e8011f9d4..3cc58828a6 100644 --- a/ding/utils/data/structure/__init__.py +++ b/ding/utils/data/structure/__init__.py @@ -1 +1,2 @@ from .cache import Cache +from .lifo_deque import LifoDeque diff --git a/ding/utils/data/structure/lifo_deque.py b/ding/utils/data/structure/lifo_deque.py new file mode 100644 index 0000000000..00d9221e5c --- /dev/null +++ b/ding/utils/data/structure/lifo_deque.py @@ -0,0 +1,12 @@ +from queue import LifoQueue +from collections 
import deque + + +class LifoDeque(LifoQueue): + """ + Like LifoQueue, but automatically replaces the oldest data when the queue is full. + """ + + def _init(self, maxsize): + self.maxsize = maxsize + 1 + self.queue = deque(maxlen=maxsize) From 472abb796d08f2e17ae4572ca366c12145661dbe Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 16:06:56 +0800 Subject: [PATCH 10/70] Add more desc (ci skip) --- ding/example/dqn_dist.py | 10 ++++++++-- ding/framework/message_queue/nng.py | 1 + ding/framework/parallel.py | 1 - 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py index 72516a62c6..8d08d3be6d 100644 --- a/ding/example/dqn_dist.py +++ b/ding/example/dqn_dist.py @@ -10,14 +10,20 @@ # Second Example —— Execute on multiple machines. 1. Execute 1 learner + 1 evaluator on one machine. + > ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. + The value of the `attach_to` parameter should be obtained from the log of the + process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). + > ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ - --ports 50517 --attach-to tcp://127.0.0.1:50515,tcp://127.0.0.1:50516 + --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 + +3. You can repeat step 2 to start more collectors on other machines. 
""" import gym import logging @@ -68,7 +74,7 @@ def main(): env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) - task.use(context_exchanger(recv_keys=["train_iter"], skip_n_iter=1)) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) task.use(model_exchanger(model, is_learner=False)) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(CkptSaver(cfg, policy, save_finish=False)) diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index feab6473a1..fd48c56585 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -30,6 +30,7 @@ def listen(self) -> None: sleep(0.1) # Wait for peers to bind for contact in self.attach_to: sock.dial(contact) + logging.info("NNG listen on {}, attach to {}".format(self.listen_to, self.attach_to)) def publish(self, topic: str, data: bytes) -> None: if not self._finished: diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 6db8808600..b61149debf 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -130,7 +130,6 @@ def _nng_args_parser( ) -> Dict[str, dict]: attach_to = attach_to or [] nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports) - logging.info("Bind subprocesses on these addresses: {}".format(nodes)) def cleanup_nodes(): for node in nodes: From f80f0475c54a6d144d06bbe52f2fe0faf2af63d5 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 15:52:02 +0800 Subject: [PATCH 11/70] Support distributed dqn Add more desc (ci skip) Add timeout on model exchanger --- ding/example/dqn_dist.py | 103 ++++++++++++++++++ ding/framework/context.py | 7 +- ding/framework/message_queue/nng.py | 1 + ding/framework/middleware/ckpt_handler.py | 8 +- .../middleware/functional/__init__.py | 1 + .../middleware/functional/evaluator.py | 5 +- .../middleware/functional/exchanger.py | 77 +++++++++++++ ding/framework/parallel.py | 6 +- ding/utils/data/structure/__init__.py | 1 + ding/utils/data/structure/lifo_deque.py | 12 ++ 10 files changed, 212 insertions(+), 9 deletions(-) create mode 100644 ding/example/dqn_dist.py create mode 100644 ding/framework/middleware/functional/exchanger.py create mode 100644 ding/utils/data/structure/lifo_deque.py diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py new file mode 100644 index 0000000000..261d211d70 --- /dev/null +++ b/ding/example/dqn_dist.py @@ -0,0 +1,103 @@ +""" +The distributed version of DQN pipeline. +With N workers = 1 learner + 1 evaluator + (N-2) actors + +# First Example —— Execute on one machine with multi processes. +Execute 4 processes with 1 learner + 1 evaluator + 2 actors +Remember to keep them connected by mesh to ensure that they can exchange information with each other. + +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 4 --topology mesh + +# Second Example —— Execute on multiple machines. +1. Execute 1 learner + 1 evaluator on one machine. + +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 + +2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). + Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. + Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. 
+ And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. + The value of the `attach_to` parameter should be obtained from the log of the + process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). + +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ + --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 + +3. You can repeat step 2 to start more collectors on other machines. +""" +import gym +import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger +from ding.utils import set_pkg_seed +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + # cfg.env.stop_value = 99999999 # Don't stop + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if task.router.node_id == 0: # Learner + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=100)) + + elif task.router.node_id == 1: # Evaluator + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + + else: # Collectors + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/framework/context.py b/ding/framework/context.py index 70c1143d0d..6b12ef9e70 100644 --- a/ding/framework/context.py +++ 
b/ding/framework/context.py @@ -7,6 +7,7 @@ class Context(dict): Context is an object that pass contextual data between middlewares, whose life cycle is only one training iteration. It is a dict that reflect itself, so you can set any properties as you wish. + Note that the initial value of the property must be equal to False. """ def __init__(self, *args, **kwargs) -> None: @@ -54,9 +55,11 @@ def __init__(self, *args, **kwargs) -> None: self.trajectories = None self.episodes = None self.trajectory_end_idx = [] + self.action = [] + self.inference_output = None # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter') @@ -74,6 +77,6 @@ def __init__(self, *args, **kwargs) -> None: self.train_output = None # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index feab6473a1..fd48c56585 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -30,6 +30,7 @@ def listen(self) -> None: sleep(0.1) # Wait for peers to bind for contact in self.attach_to: sock.dial(contact) + logging.info("NNG listen on {}, attach to {}".format(self.listen_to, self.attach_to)) def publish(self, topic: str, data: bytes) -> None: if not self._finished: diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 4657eff8d9..1bfb7e23cf 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -17,7 +17,7 @@ class CkptSaver: The class used to save checkpoint data. """ - def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None): + def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None, save_finish: bool = True): """ Overview: Initialize the `CkptSaver`. @@ -25,6 +25,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.exp_name`. - policy (:obj:`Policy`): Policy used to save the checkpoint. - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data. + - save_finish (:obj:`bool`): Whether save the ckpt at finish. 
""" self.policy = policy self.train_freq = train_freq @@ -33,6 +34,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No os.mkdir(self.prefix) self.last_save_iter = 0 self.max_eval_value = -np.inf + self.save_finish = save_finish def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: """ @@ -54,10 +56,10 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.last_save_iter = ctx.train_iter # best eval reward so far - if ctx.eval_value and ctx.eval_value > self.max_eval_value: + if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value: save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) self.max_eval_value = ctx.eval_value # finish - if task.finish: + if task.finish and self.save_finish: save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 463ee84284..743c61f8e3 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -6,6 +6,7 @@ from .termination_checker import termination_checker from .pace_controller import pace_controller from .logger import online_logger +from .exchanger import context_exchanger, model_exchanger # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 4d864b3bae..b1c9599953 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -169,7 +169,8 @@ def _evaluate(ctx: "OnlineRLContext"): - eval_value (:obj:`float`): The average reward in the current evaluation. """ - if ctx.last_eval_iter != -1 and \ + # evaluation will be executed if the task begins or enough train_iter after last evaluation + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return @@ -214,7 +215,7 @@ def metric_evaluator(cfg: EasyDict, policy: Policy, dataset: Dataset, metric: IM def _evaluate(ctx: "Context"): # evaluation will be executed if the task begins or enough train_iter after last evaluation - if ctx.last_eval_iter != -1 and \ + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return diff --git a/ding/framework/middleware/functional/exchanger.py b/ding/framework/middleware/functional/exchanger.py new file mode 100644 index 0000000000..d22578c13a --- /dev/null +++ b/ding/framework/middleware/functional/exchanger.py @@ -0,0 +1,77 @@ +import logging +from queue import Empty +from typing import TYPE_CHECKING, List, Dict +from ding.framework import task +from ding.utils.data.structure.lifo_deque import LifoDeque +if TYPE_CHECKING: + from ding.framework.context import Context + from torch.nn import Module + + +def context_exchanger(send_keys: List[str] = None, recv_keys: List[str] = None, skip_n_iter: int = 0): + """ + Overview: + Send data from context in the backward stage. + Buffer received data and wait if not get any data. + Arguments: + - send_keys (:obj:`List[str]`): Keys need to be sent. + - recv_keys (:obj:`List[str]`): Keys need to be received. + - skip_n_iter (:obj:`int`): Whether to skip the first N round of waiting, + e.g. 
collecting data without waiting for a new model in the first N round, + while training a model that needs to wait for data in the first round. + """ + event_name = "context_exchanger" + + bufferd_payloads = LifoDeque(maxsize=100) + task.on(event_name, lambda payload: bufferd_payloads.put(payload)) + + def _context_exchanger(ctx: "Context"): + if recv_keys: + if ctx.total_step >= skip_n_iter: + payload: Dict = bufferd_payloads.get() + for key in recv_keys: + value = payload.get(key) + if value: + ctx[key] = value + + if send_keys: + yield + payload = {} + for key in send_keys: + payload[key] = ctx.get(key) + if payload: + task.emit(event_name, payload, only_remote=True) + + return _context_exchanger + + +def model_exchanger(model: "Module", is_learner: bool = False): + """ + Overview: + Exchange model between processes, only the learner will send the model, + otherwise the model will only be received. + If you are using a shared model on a single host, there is no need to use this middleware. + Arguments: + - model (:obj:`torch.nn.Module`): Pytorch module. + - is_learner (:obj:`bool`): Whether use this middleware as learner or not. + """ + event_name = "model_exchanger" + bufferd_state_dict = LifoDeque(maxsize=1) + + if not is_learner: + task.on(event_name, lambda state_dict: bufferd_state_dict.put(state_dict)) + + def _model_exchanger(ctx: "Context"): + if not is_learner: + if ctx.total_step != 0: # Skip first iteration + try: + state_dict = bufferd_state_dict.get(timeout=5) + model.load_state_dict(state_dict) + except Empty: + logging.warning("Timeout when waiting for new model!") + + if is_learner: + yield + task.emit(event_name, model.state_dict(), only_remote=True) + + return _model_exchanger diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 469ae7e77f..b61149debf 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -34,6 +34,7 @@ def __init__(self) -> None: def _run( self, node_id: int, + n_parallel_workers: int, labels: Optional[Set[str]] = None, auto_recover: bool = False, max_retries: int = float("inf"), @@ -41,6 +42,7 @@ def _run( **kwargs ) -> None: self.node_id = node_id + self.n_parallel_workers = n_parallel_workers self.labels = labels or set() self.auto_recover = auto_recover self.max_retries = max_retries @@ -128,7 +130,6 @@ def _nng_args_parser( ) -> Dict[str, dict]: attach_to = attach_to or [] nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports) - logging.info("Bind subprocesses on these addresses: {}".format(nodes)) def cleanup_nodes(): for node in nodes: @@ -156,6 +157,7 @@ def topology_network(i: int) -> List[str]: "node_id": candidate_node_ids[i], "listen_to": nodes[i], "attach_to": topology_network(i), + "n_parallel_workers": n_parallel_workers, } runner_params.append(runner_kwargs) @@ -166,7 +168,7 @@ def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[Li runner_params = [] candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0) for i in range(n_parallel_workers): - runner_kwargs = {**kwargs, "node_id": candidate_node_ids[i]} + runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]} runner_params.append(runner_kwargs) return runner_params diff --git a/ding/utils/data/structure/__init__.py b/ding/utils/data/structure/__init__.py index 9e8011f9d4..3cc58828a6 100644 --- a/ding/utils/data/structure/__init__.py +++ b/ding/utils/data/structure/__init__.py @@ -1 +1,2 @@ from .cache import Cache 
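The `LifoDeque` exported just below behaves like `queue.LifoQueue` except that `put` never blocks when the buffer is full: it is backed by a bounded `collections.deque`, so the oldest entry is silently dropped instead. A self-contained usage sketch (the class body is copied from the patch so the snippet runs on its own):

from queue import LifoQueue
from collections import deque


class LifoDeque(LifoQueue):
    """Bounded LIFO buffer that overwrites the oldest item, as in the patch."""

    def _init(self, maxsize):
        self.maxsize = maxsize + 1          # keeps LifoQueue's full-check from ever firing
        self.queue = deque(maxlen=maxsize)


buf = LifoDeque(maxsize=2)
for state_dict in ("v1", "v2", "v3"):       # e.g. three model broadcasts arrive back to back
    buf.put(state_dict)                     # never blocks; "v1" is evicted
print(buf.get())                            # -> "v3", the freshest payload wins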
+from .lifo_deque import LifoDeque diff --git a/ding/utils/data/structure/lifo_deque.py b/ding/utils/data/structure/lifo_deque.py new file mode 100644 index 0000000000..00d9221e5c --- /dev/null +++ b/ding/utils/data/structure/lifo_deque.py @@ -0,0 +1,12 @@ +from queue import LifoQueue +from collections import deque + + +class LifoDeque(LifoQueue): + """ + Like LifoQueue, but automatically replaces the oldest data when the queue is full. + """ + + def _init(self, maxsize): + self.maxsize = maxsize + 1 + self.queue = deque(maxlen=maxsize) From 43f6f01653138245e5efb8436c9ea6a0e129e16b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 26 May 2022 11:39:43 +0800 Subject: [PATCH 12/70] feature(nyz): add online logger freq --- ding/example/dqn_dist.py | 2 +- ding/framework/middleware/functional/logger.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py index 8d08d3be6d..0fd1e09111 100644 --- a/ding/example/dqn_dist.py +++ b/ding/example/dqn_dist.py @@ -26,7 +26,7 @@ 3. You can repeat step 2 to start more collectors on other machines. """ import gym -import logging +from ditk import logging from ding.model import DQN from ding.policy import DQNPolicy from ding.envs import DingEnvWrapper, BaseEnvManagerV2 diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index bec65cbd1f..915364c9c1 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -6,17 +6,20 @@ from ding.framework import OnlineRLContext -def online_logger(record_train_iter: bool = False) -> Callable: +def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: writer = DistributedWriter.get_instance() + last_train_show_iter = -1 def _logger(ctx: "OnlineRLContext"): + nonlocal last_train_show_iter if ctx.eval_value is not None: if record_train_iter: writer.add_scalar('basic/eval_episode_reward_mean-env_step', ctx.eval_value, ctx.env_step) writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) else: writer.add_scalar('basic/eval_episode_reward_mean', ctx.eval_value, ctx.env_step) - if ctx.train_output is not None: + if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: + last_train_show_iter = ctx.train_iter if isinstance(ctx.train_output, deque): output = ctx.train_output.pop() # only use latest output else: From 1e0c4a14d573536d95febdaa7bddf15080e80d9b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 1 Jun 2022 21:37:51 +0800 Subject: [PATCH 13/70] fix(nyz): fix policy set device bug --- ding/policy/base_policy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index a30b080015..441d5e7f1a 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -75,7 +75,6 @@ def __init__( if len(set(self._enable_field).intersection(set(['learn']))) > 0: self._rank = get_rank() if self._cfg.learn.multi_gpu else 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() if self._cfg.learn.multi_gpu: bp_update_sync = self._cfg.learn.get('bp_update_sync', True) @@ -84,7 +83,6 @@ def __init__( else: self._rank = 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() self._model = model self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' From 
c043ebdccb4cb4a8f454f1aa12ff9177c6bf9843 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 06:56:37 +0000 Subject: [PATCH 14/70] add offline rl logger --- ding/example/cql.py | 6 ++-- .../middleware/functional/__init__.py | 2 +- .../middleware/functional/evaluator.py | 20 +++++++---- .../framework/middleware/functional/logger.py | 33 +++++++++++++++++-- .../middleware/functional/trainer.py | 18 ++++++---- 5 files changed, 61 insertions(+), 18 deletions(-) diff --git a/ding/example/cql.py b/ding/example/cql.py index 5651121d8d..1e1c678dd0 100644 --- a/ding/example/cql.py +++ b/ding/example/cql.py @@ -5,9 +5,9 @@ from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import create_dataset from ding.config import compile_config -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger from ding.utils import set_pkg_seed from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config @@ -18,6 +18,7 @@ def main(): # For demostration, we also can train a RL policy (e.g. SAC) and collect some data logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager @@ -33,6 +34,7 @@ def main(): task.use(offline_data_fetcher(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(offline_logger()) task.run() diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 743c61f8e3..b34596b9dc 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -5,7 +5,7 @@ from .evaluator import interaction_evaluator from .termination_checker import termination_checker from .pace_controller import pace_controller -from .logger import online_logger +from .logger import online_logger, offline_logger from .exchanger import context_exchanger, model_exchanger # algorithm diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index b1c9599953..55ce6d197b 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -13,8 +13,7 @@ from ding.torch_utils import tensor_to_list from ding.utils import lists_to_dicts -if TYPE_CHECKING: - from ding.framework import Context, OnlineRLContext +from ding.framework import Context, OnlineRLContext, OfflineRLContext class IMetric(ABC): @@ -157,7 +156,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> env.seed(cfg.seed, dynamic_seed=False) - def _evaluate(ctx: "OnlineRLContext"): + def _evaluate(ctx: "Context"): """ Overview: - The evaluation will be executed if the task begins and enough train_iter passed \ @@ -196,11 +195,18 @@ def _evaluate(ctx: "OnlineRLContext"): episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) 
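        # Note: this patch also moves Context/OnlineRLContext/OfflineRLContext out of the
        # `if TYPE_CHECKING:` guard (see the import hunk above), because the branches added
        # below call isinstance() on them at runtime; a name imported only under
        # TYPE_CHECKING does not exist when the module actually runs, e.g.:
        #
        #   from typing import TYPE_CHECKING
        #   if TYPE_CHECKING:
        #       from ding.framework import OnlineRLContext
        #
        #   def f(ctx):
        #       return isinstance(ctx, OnlineRLContext)  # NameError at runtime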
stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0 - logging.info( - 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( - ctx.train_iter, ctx.env_step, eval_reward + if isinstance(ctx, OnlineRLContext): + logging.info( + 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( + ctx.train_iter, ctx.env_step, eval_reward + ) + ) + elif isinstance(ctx, OfflineRLContext): + logging.info( + 'Evaluation: Train Iter({})\tEval Reward({:.3f})'.format( + ctx.train_iter, eval_reward + ) ) - ) ctx.last_eval_iter = ctx.train_iter ctx.eval_value = eval_reward diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 915364c9c1..104be482bb 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -3,7 +3,7 @@ from ding.utils import DistributedWriter if TYPE_CHECKING: - from ding.framework import OnlineRLContext + from ding.framework import OnlineRLContext, OfflineRLContext def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: @@ -48,4 +48,33 @@ def _logger(ctx: "OnlineRLContext"): return _logger -# TODO offline logger +def offline_logger(record_train_iter: bool = False) -> Callable: + writer = DistributedWriter.get_instance() + + def _logger(ctx: "OfflineRLContext"): + if ctx.eval_value is not None: + if record_train_iter: + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + if ctx.train_output is not None: + if isinstance(ctx.train_output, deque): + output = ctx.train_output.pop() # only use latest output + else: + output = ctx.train_output + # TODO(nyz) ppo train log case + if isinstance(output, List): + raise NotImplementedError + for k, v in output.items(): + if k in ['priority']: + continue + if "[scalars]" in k: + new_k = k.split(']')[-1] + raise NotImplementedError + elif "[histogram]" in k: + new_k = k.split(']')[-1] + if record_train_iter: + writer.add_histogram(new_k, v, ctx.train_iter) + else: + if record_train_iter: + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + return _logger + diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 5c7f99467b..f9fd432733 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -5,8 +5,7 @@ from ding.policy import Policy from ding.framework import task -if TYPE_CHECKING: - from ding.framework import OnlineRLContext, OfflineRLContext +from ding.framework import OnlineRLContext, OfflineRLContext def trainer(cfg: EasyDict, policy: Policy) -> Callable: @@ -33,11 +32,18 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return train_output = policy.forward(ctx.train_data) if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0: - logging.info( - 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( - ctx.train_iter, ctx.env_step, train_output['total_loss'] + if isinstance(ctx, OnlineRLContext): + logging.info( + 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( + ctx.train_iter, ctx.env_step, train_output['total_loss'] + ) ) - ) + elif isinstance(ctx, OfflineRLContext): + logging.info( + 'Training: Train Iter({})\tLoss({:.3f})'.format( + ctx.train_iter, train_output['total_loss'] + ) + ) ctx.train_iter += 1 ctx.train_output = train_output From df8719ac6a3f3c6bd66e1f60ecb30bd3081b3446 Mon Sep 17 00:00:00 
2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 08:09:03 +0000 Subject: [PATCH 15/70] change a bit --- .../middleware/functional/evaluator.py | 2 +- .../framework/middleware/functional/logger.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 55ce6d197b..721e3682d4 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,7 +156,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> env.seed(cfg.seed, dynamic_seed=False) - def _evaluate(ctx: "Context"): + def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): """ Overview: - The evaluation will be executed if the task begins and enough train_iter passed \ diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 104be482bb..e1272d5a88 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -48,21 +48,14 @@ def _logger(ctx: "OnlineRLContext"): return _logger -def offline_logger(record_train_iter: bool = False) -> Callable: +def offline_logger() -> Callable: writer = DistributedWriter.get_instance() def _logger(ctx: "OfflineRLContext"): if ctx.eval_value is not None: - if record_train_iter: - writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) if ctx.train_output is not None: - if isinstance(ctx.train_output, deque): - output = ctx.train_output.pop() # only use latest output - else: - output = ctx.train_output - # TODO(nyz) ppo train log case - if isinstance(output, List): - raise NotImplementedError + output = ctx.train_output for k, v in output.items(): if k in ['priority']: continue @@ -71,10 +64,8 @@ def _logger(ctx: "OfflineRLContext"): raise NotImplementedError elif "[histogram]" in k: new_k = k.split(']')[-1] - if record_train_iter: - writer.add_histogram(new_k, v, ctx.train_iter) + writer.add_histogram(new_k, v, ctx.train_iter) else: - if record_train_iter: - writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) return _logger From 7ade025275895aab97162db9492d2ae884190703 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 08:18:15 +0000 Subject: [PATCH 16/70] add else in checking ctx type --- ding/framework/middleware/functional/evaluator.py | 2 ++ ding/framework/middleware/functional/trainer.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 721e3682d4..3c4785053e 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -207,6 +207,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ctx.train_iter, eval_reward ) ) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter ctx.eval_value = eval_reward diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index f9fd432733..3d5db201d0 100644 --- a/ding/framework/middleware/functional/trainer.py +++ 
b/ding/framework/middleware/functional/trainer.py @@ -43,7 +43,9 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): 'Training: Train Iter({})\tLoss({:.3f})'.format( ctx.train_iter, train_output['total_loss'] ) - ) + ) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.train_iter += 1 ctx.train_output = train_output From fe6a32f485f44a622d9136353dba5a31d53b59df Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 14:45:58 +0000 Subject: [PATCH 17/70] add test_logger.py --- .../framework/middleware/tests/test_logger.py | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 ding/framework/middleware/tests/test_logger.py diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py new file mode 100644 index 0000000000..07a5b3ddf1 --- /dev/null +++ b/ding/framework/middleware/tests/test_logger.py @@ -0,0 +1,142 @@ +import pytest +from ding.framework import OnlineRLContext, OfflineRLContext, ding_init +from ding.framework.middleware.functional import online_logger, offline_logger +from easydict import EasyDict +import os +from os import path +import shutil +from collections import deque + +test_folder = "test_exp" +test_path = path.join(os.getcwd(), test_folder) +cfg = EasyDict({"exp_name": "test_exp"}) + +@pytest.fixture(scope='function') +def online_ctx_output_dict(): + ctx = OnlineRLContext() + ctx.eval_value = -10000 + ctx.train_iter = 34 + ctx.env_step = 78 + ctx.train_output = { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + } + return ctx + +@pytest.fixture(scope='function') +def online_ctx_output_deque(): + ctx = OnlineRLContext() + ctx.eval_value = -600 + ctx.train_iter = 24 + ctx.env_step = 30 + ctx.train_output = deque([ + { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + }, + { + 'priority': [108], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 30 + } + ]) + return ctx + +@pytest.fixture(scope='function') +def online_ctx_output_list(): + ctx = OnlineRLContext() + ctx.eval_value = -1000000 + ctx.train_iter = 23232 + ctx.env_step = 33333 + ctx.train_output = [ + { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + }, + { + 'priority': [108], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 30 + } + ] + return ctx + +@pytest.fixture(scope='function') +def online_scalar_ctx(): + ctx = OfflineRLContext() + ctx.eval_value = -777888 + ctx.train_iter = 2233 + ctx.env_step = 32323 + ctx.train_output = { + '[scalars]': 1 + } + return ctx + + +@pytest.mark.zms +class TestOnlineLogger: + + def test_online_logger_output_dict(self, online_ctx_output_dict): + ding_init(cfg) + online_logger()(online_ctx_output_dict) + + def test_online_logger_record_output_dict(self, online_ctx_output_dict): + ding_init(cfg) + online_logger(record_train_iter=True)(online_ctx_output_dict) + + def test_online_logger_record_output_deque(self, online_ctx_output_deque): + ding_init(cfg) + online_logger()(online_ctx_output_deque) + + def test_online_logger_record_output_list(self, online_ctx_output_list): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_ctx_output_list) + + def test_online_logger_scalars(self, online_scalar_ctx): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_scalar_ctx) + + +@pytest.fixture(scope='function') +def 
offline_ctx_output_dict(): + ctx = OfflineRLContext() + ctx.eval_value = -10000000000 + ctx.train_iter = 3323233 + ctx.train_output = { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + } + return ctx + +@pytest.fixture(scope='function') +def offline_scalar_ctx(): + ctx = OfflineRLContext() + ctx.eval_value = -232 + ctx.train_iter = 3333 + ctx.train_output = { + '[scalars]': 1 + } + return ctx + +@pytest.mark.zms +class TestOfflineLogger: + + def test_offline_logger_no_scalars(self, offline_ctx_output_dict): + ding_init(cfg) + offline_logger()(offline_ctx_output_dict) + + def test_offline_logger_scalars(self, offline_scalar_ctx): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + offline_logger()(offline_scalar_ctx) + + assert path.exists(test_path) + if path.exists(test_path): + shutil.rmtree(test_path) + From 4754c808bcf737a2d6a81471c7403752862d94a7 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 15:42:10 +0000 Subject: [PATCH 18/70] add mock of offline_logger --- .../framework/middleware/tests/test_logger.py | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index 07a5b3ddf1..af4493918a 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -6,6 +6,9 @@ from os import path import shutil from collections import deque +from unittest.mock import Mock, patch +from ding.utils import DistributedWriter +import copy test_folder = "test_exp" test_path = path.join(os.getcwd(), test_folder) @@ -101,12 +104,10 @@ def test_online_logger_scalars(self, online_scalar_ctx): with pytest.raises(NotImplementedError) as exc_info: online_logger()(online_scalar_ctx) - -@pytest.fixture(scope='function') -def offline_ctx_output_dict(): +def get_offline_ctx(): ctx = OfflineRLContext() ctx.eval_value = -10000000000 - ctx.train_iter = 3323233 + ctx.train_iter = 3333 ctx.train_output = { 'priority': [107], '[histogram]test_histogram': [1,2,3,4,5,6], @@ -114,29 +115,52 @@ def offline_ctx_output_dict(): } return ctx +@pytest.fixture(scope='function') +def offline_ctx_output_dict(): + ctx = get_offline_ctx() + return ctx + @pytest.fixture(scope='function') def offline_scalar_ctx(): - ctx = OfflineRLContext() - ctx.eval_value = -232 - ctx.train_iter = 3333 + ctx = get_offline_ctx() ctx.train_output = { '[scalars]': 1 } return ctx -@pytest.mark.zms +class MockWriter: + + def __init__(self): + self.ctx = get_offline_ctx() + + def add_scalar(self, tag, scalar_value, global_step): + assert global_step == self.ctx.train_iter + if tag == 'basic/eval_episode_reward_mean-train_iter': + assert scalar_value == self.ctx.eval_value + elif tag == 'basic/train_td_error-train_iter': + assert scalar_value == self.ctx.train_output['td_error'] + else: + raise NotImplementedError('tag should be in the tags defined') + + def add_histogram(self, tag, values, global_step): + assert tag == 'test_histogram' + assert values == [1,2,3,4,5,6] + assert global_step == self.ctx.train_iter + +def mock_get_instance(): + return MockWriter() + + +@pytest.mark.offline class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): - ding_init(cfg) - offline_logger()(offline_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + offline_logger()(offline_ctx_output_dict) def 
test_offline_logger_scalars(self, offline_scalar_ctx): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - offline_logger()(offline_scalar_ctx) - - assert path.exists(test_path) - if path.exists(test_path): - shutil.rmtree(test_path) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with pytest.raises(NotImplementedError) as exc_info: + offline_logger()(offline_scalar_ctx) + From 83869c8f4492c6f124266b97f63ae7ddfd48ae8d Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:06:46 +0000 Subject: [PATCH 19/70] add mock of online writer --- .../framework/middleware/tests/test_logger.py | 111 +++++++++--------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index af4493918a..bf2d5df45d 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -14,8 +14,7 @@ test_path = path.join(os.getcwd(), test_folder) cfg = EasyDict({"exp_name": "test_exp"}) -@pytest.fixture(scope='function') -def online_ctx_output_dict(): +def get_online_ctx(): ctx = OnlineRLContext() ctx.eval_value = -10000 ctx.train_iter = 34 @@ -27,82 +26,88 @@ def online_ctx_output_dict(): } return ctx +@pytest.fixture(scope='function') +def online_ctx_output_dict(): + ctx = get_online_ctx() + return ctx + @pytest.fixture(scope='function') def online_ctx_output_deque(): - ctx = OnlineRLContext() - ctx.eval_value = -600 - ctx.train_iter = 24 - ctx.env_step = 30 + ctx = get_online_ctx() ctx.train_output = deque([ - { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - }, - { - 'priority': [108], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 30 - } + ctx.train_output ]) return ctx @pytest.fixture(scope='function') def online_ctx_output_list(): - ctx = OnlineRLContext() - ctx.eval_value = -1000000 - ctx.train_iter = 23232 - ctx.env_step = 33333 + ctx = get_online_ctx() ctx.train_output = [ - { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - }, - { - 'priority': [108], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 30 - } + ctx.train_output ] return ctx @pytest.fixture(scope='function') def online_scalar_ctx(): - ctx = OfflineRLContext() - ctx.eval_value = -777888 - ctx.train_iter = 2233 - ctx.env_step = 32323 + ctx = get_online_ctx() ctx.train_output = { '[scalars]': 1 } return ctx +class MockOnlineWriter: + def __init__(self): + self.ctx = get_online_ctx() + + def add_scalar(self, tag, scalar_value, global_step): + if tag in ['basic/eval_episode_reward_mean-env_step', 'basic/eval_episode_reward_mean']: + assert scalar_value == self.ctx.eval_value + assert global_step == self.ctx.env_step + elif tag == 'basic/eval_episode_reward_mean-train_iter': + assert scalar_value == self.ctx.eval_value + assert global_step == self.ctx.train_iter + elif tag in ['basic/train_td_error-env_step', 'basic/train_td_error']: + assert scalar_value == self.ctx.train_output['td_error'] + assert global_step == self.ctx.env_step + elif tag == 'basic/train_td_error-train_iter': + assert scalar_value == self.ctx.train_output['td_error'] + assert global_step == self.ctx.train_iter + else: + raise NotImplementedError('tag should be in the tags defined') + + def add_histogram(self, tag, values, global_step): + assert tag == 'test_histogram' + assert values == [1,2,3,4,5,6] + assert global_step in [self.ctx.train_iter, 
self.ctx.env_step] + +def mock_get_online_instance(): + return MockOnlineWriter() -@pytest.mark.zms +@pytest.mark.unittest class TestOnlineLogger: def test_online_logger_output_dict(self, online_ctx_output_dict): - ding_init(cfg) - online_logger()(online_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger()(online_ctx_output_dict) def test_online_logger_record_output_dict(self, online_ctx_output_dict): - ding_init(cfg) - online_logger(record_train_iter=True)(online_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger(record_train_iter=True)(online_ctx_output_dict) def test_online_logger_record_output_deque(self, online_ctx_output_deque): - ding_init(cfg) - online_logger()(online_ctx_output_deque) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger()(online_ctx_output_deque) def test_online_logger_record_output_list(self, online_ctx_output_list): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_ctx_output_list) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_ctx_output_list) def test_online_logger_scalars(self, online_scalar_ctx): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_scalar_ctx) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_scalar_ctx) + def get_offline_ctx(): ctx = OfflineRLContext() @@ -128,7 +133,7 @@ def offline_scalar_ctx(): } return ctx -class MockWriter: +class MockOfflineWriter: def __init__(self): self.ctx = get_offline_ctx() @@ -147,19 +152,19 @@ def add_histogram(self, tag, values, global_step): assert values == [1,2,3,4,5,6] assert global_step == self.ctx.train_iter -def mock_get_instance(): - return MockWriter() +def mock_get_offline_instance(): + return MockOfflineWriter() -@pytest.mark.offline +@pytest.mark.unittest class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): offline_logger()(offline_ctx_output_dict) def test_offline_logger_scalars(self, offline_scalar_ctx): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): with pytest.raises(NotImplementedError) as exc_info: offline_logger()(offline_scalar_ctx) From 302c824393e95e7b2769482ff46f03b0662b4de1 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:09:23 +0000 Subject: [PATCH 20/70] reformat --- .../framework/middleware/tests/test_logger.py | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index bf2d5df45d..56faf43a62 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -14,51 +14,48 @@ test_path = path.join(os.getcwd(), test_folder) cfg = EasyDict({"exp_name": "test_exp"}) + def get_online_ctx(): ctx = OnlineRLContext() ctx.eval_value = -10000 ctx.train_iter = 
34 ctx.env_step = 78 - ctx.train_output = { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - } + ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} return ctx + @pytest.fixture(scope='function') def online_ctx_output_dict(): ctx = get_online_ctx() return ctx + @pytest.fixture(scope='function') def online_ctx_output_deque(): ctx = get_online_ctx() - ctx.train_output = deque([ - ctx.train_output - ]) + ctx.train_output = deque([ctx.train_output]) return ctx + @pytest.fixture(scope='function') def online_ctx_output_list(): ctx = get_online_ctx() - ctx.train_output = [ - ctx.train_output - ] + ctx.train_output = [ctx.train_output] return ctx + @pytest.fixture(scope='function') def online_scalar_ctx(): ctx = get_online_ctx() - ctx.train_output = { - '[scalars]': 1 - } + ctx.train_output = {'[scalars]': 1} return ctx + class MockOnlineWriter: + def __init__(self): self.ctx = get_online_ctx() - + def add_scalar(self, tag, scalar_value, global_step): if tag in ['basic/eval_episode_reward_mean-env_step', 'basic/eval_episode_reward_mean']: assert scalar_value == self.ctx.eval_value @@ -74,15 +71,17 @@ def add_scalar(self, tag, scalar_value, global_step): assert global_step == self.ctx.train_iter else: raise NotImplementedError('tag should be in the tags defined') - + def add_histogram(self, tag, values, global_step): assert tag == 'test_histogram' - assert values == [1,2,3,4,5,6] + assert values == [1, 2, 3, 4, 5, 6] assert global_step in [self.ctx.train_iter, self.ctx.env_step] + def mock_get_online_instance(): return MockOnlineWriter() + @pytest.mark.unittest class TestOnlineLogger: @@ -97,12 +96,12 @@ def test_online_logger_record_output_dict(self, online_ctx_output_dict): def test_online_logger_record_output_deque(self, online_ctx_output_deque): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): online_logger()(online_ctx_output_deque) - + def test_online_logger_record_output_list(self, online_ctx_output_list): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): with pytest.raises(NotImplementedError) as exc_info: online_logger()(online_ctx_output_list) - + def test_online_logger_scalars(self, online_scalar_ctx): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): with pytest.raises(NotImplementedError) as exc_info: @@ -113,31 +112,28 @@ def get_offline_ctx(): ctx = OfflineRLContext() ctx.eval_value = -10000000000 ctx.train_iter = 3333 - ctx.train_output = { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - } + ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} return ctx + @pytest.fixture(scope='function') def offline_ctx_output_dict(): ctx = get_offline_ctx() return ctx + @pytest.fixture(scope='function') def offline_scalar_ctx(): ctx = get_offline_ctx() - ctx.train_output = { - '[scalars]': 1 - } + ctx.train_output = {'[scalars]': 1} return ctx + class MockOfflineWriter: def __init__(self): self.ctx = get_offline_ctx() - + def add_scalar(self, tag, scalar_value, global_step): assert global_step == self.ctx.train_iter if tag == 'basic/eval_episode_reward_mean-train_iter': @@ -149,9 +145,10 @@ def add_scalar(self, tag, scalar_value, global_step): def add_histogram(self, tag, values, global_step): assert tag == 'test_histogram' - assert values == [1,2,3,4,5,6] + assert values == [1, 2, 3, 4, 5, 6] assert global_step == 
self.ctx.train_iter + def mock_get_offline_instance(): return MockOfflineWriter() @@ -162,10 +159,8 @@ class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): offline_logger()(offline_ctx_output_dict) - + def test_offline_logger_scalars(self, offline_scalar_ctx): with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): with pytest.raises(NotImplementedError) as exc_info: offline_logger()(offline_scalar_ctx) - - From 51e3e0e03394cfb00d73d8103a5eec4e51f3a32d Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:13:25 +0000 Subject: [PATCH 21/70] reformat --- ding/framework/middleware/functional/evaluator.py | 6 +----- ding/framework/middleware/functional/logger.py | 2 +- ding/framework/middleware/functional/trainer.py | 6 ++---- ding/framework/middleware/tests/test_logger.py | 10 ---------- 4 files changed, 4 insertions(+), 20 deletions(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 3c4785053e..f06f1b8602 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -202,11 +202,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ) ) elif isinstance(ctx, OfflineRLContext): - logging.info( - 'Evaluation: Train Iter({})\tEval Reward({:.3f})'.format( - ctx.train_iter, eval_reward - ) - ) + logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, eval_reward)) else: raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index e1272d5a88..40e2f23732 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -67,5 +67,5 @@ def _logger(ctx: "OfflineRLContext"): writer.add_histogram(new_k, v, ctx.train_iter) else: writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) - return _logger + return _logger diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 3d5db201d0..990ff69f30 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -40,12 +40,10 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ) elif isinstance(ctx, OfflineRLContext): logging.info( - 'Training: Train Iter({})\tLoss({:.3f})'.format( - ctx.train_iter, train_output['total_loss'] - ) + 'Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output['total_loss']) ) else: - raise TypeError("not supported ctx type: {}".format(type(ctx))) + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.train_iter += 1 ctx.train_output = train_output diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index 56faf43a62..1c0a772f1d 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -97,16 +97,6 @@ def test_online_logger_record_output_deque(self, online_ctx_output_deque): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): online_logger()(online_ctx_output_deque) - def test_online_logger_record_output_list(self, online_ctx_output_list): - with 
patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_ctx_output_list) - - def test_online_logger_scalars(self, online_scalar_ctx): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_scalar_ctx) - def get_offline_ctx(): ctx = OfflineRLContext() From aa29252bdac8d106373fc9e1f4ab1efa2ae28778 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 7 Jun 2022 21:25:33 +0800 Subject: [PATCH 22/70] feature(nyz): polish atari ddp demo and add dist demo --- .../middleware/functional/__init__.py | 2 +- .../functional/termination_checker.py | 25 ++++++ dizoo/atari/example/atari_dqn_ddp.py | 32 +------ dizoo/atari/example/atari_dqn_dist.py | 85 +++++++++++++++++++ 4 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 dizoo/atari/example/atari_dqn_dist.py diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 743c61f8e3..681933cfa1 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -3,7 +3,7 @@ sqil_data_pusher from .collector import inferencer, rolloutor, TransitionList from .evaluator import interaction_evaluator -from .termination_checker import termination_checker +from .termination_checker import termination_checker, ddp_termination_checker from .pace_controller import pace_controller from .logger import online_logger from .exchanger import context_exchanger, model_exchanger diff --git a/ding/framework/middleware/functional/termination_checker.py b/ding/framework/middleware/functional/termination_checker.py index 58c371d57b..b6879c2a16 100644 --- a/ding/framework/middleware/functional/termination_checker.py +++ b/ding/framework/middleware/functional/termination_checker.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, Union, Callable, Optional import numpy as np +import torch +from ding.utils import broadcast from ding.framework import task if TYPE_CHECKING: @@ -20,3 +22,26 @@ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]): task.finish = True return _check + + +def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): + if rank == 0: + if max_env_step is None: + max_env_step = np.inf + if max_train_iter is None: + max_train_iter = np.inf + + def _check(ctx): + if rank == 0: + if ctx.env_step > max_env_step: + finish = torch.ones(1).long().cuda() + elif ctx.train_iter > max_train_iter: + finish = torch.ones(1).long().cuda() + else: + finish = torch.LongTensor([task.finish]).cuda() + else: + finish = torch.zeros(1).long().cuda() + broadcast(finish, 0) + task.finish = finish.cpu().bool().item() + + return _check diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py index bc07da3bfb..22e30eba89 100644 --- a/dizoo/atari/example/atari_dqn_ddp.py +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -6,46 +6,21 @@ from ding.data import DequeBuffer from ding.config import compile_config from ding.utils import DistContext, get_rank -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver, nstep_reward_enhancer + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, 
online_logger, ddp_termination_checker from ding.utils import set_pkg_seed from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config -def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): - import numpy as np - import torch - from ding.utils import broadcast - if rank == 0: - if max_env_step is None: - max_env_step = np.inf - if max_train_iter is None: - max_train_iter = np.inf - - def _check(ctx): - if rank == 0: - if ctx.env_step > max_env_step: - finish = torch.ones(1).long().cuda() - elif ctx.train_iter > max_train_iter: - finish = torch.ones(1).long().cuda() - else: - finish = torch.LongTensor([task.finish]).cuda() - else: - finish = torch.zeros(1).long().cuda() - broadcast(finish, 0) - task.finish = finish.cpu().bool().item() - - return _check - - def main(): logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'pong_dqn_seed0_ddp' main_config.policy.learn.multi_gpu = True cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with DistContext(): rank = get_rank() with task.start(async_mode=False, ctx=OnlineRLContext()): @@ -75,6 +50,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) if rank == 0: task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(online_logger(record_train_iter=True)) task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) task.run() diff --git a/dizoo/atari/example/atari_dqn_dist.py b/dizoo/atari/example/atari_dqn_dist.py new file mode 100644 index 0000000000..d692a9f3e3 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist.py @@ -0,0 +1,85 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ditask_dist' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
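        # Note: unlike the DDP entry above (one training process per GPU with gradients
        # synchronized by multi_gpu), this "dist" entry splits the pipeline into learner /
        # collector / evaluator processes selected via task.router.labels, which talk only
        # through the exchanger middlewares. The exchanger keys below have to mirror each
        # other across processes, and only the learner publishes model weights; roughly:
        #
        #   # learner side
        #   task.use(context_exchanger(send_keys=["train_iter"],
        #                              recv_keys=["trajectories", "episodes",
        #                                         "env_step", "env_episode"]))
        #   task.use(model_exchanger(model, is_learner=True))
        #   # collector side mirrors it: send trajectories/env_step, receive train_iter,
        #   # and load the broadcast weights via model_exchanger(model, is_learner=False)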
+ + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'evaluator' in task.router.labels: + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() From 16d810773b2d7b9b8568685c8751a3095a26c93f Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 15:20:53 +0800 Subject: [PATCH 23/70] fix(nyz): fix mq listen bug when stop --- ding/framework/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index b61149debf..7003a47840 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -264,7 +264,7 @@ def padding_param(cls, int_or_list: Optional[Union[List[int], int]], n_max: int, def listen(self): self._mq.listen() - while True: + while self._mq is not None: msg = self._mq.recv() # msg is none means that the message queue is no longer being listened to, # especially if the message queue is already closed From 01c9868c1a6ba7435fadace54e90130748fa040a Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 15:23:17 +0800 Subject: [PATCH 24/70] demo(nyz): add atari ppo(sm+ddp) demo --- .../functional/advantage_estimator.py | 4 ++ .../functional/termination_checker.py | 6 +++ dizoo/atari/example/atari_ppo.py | 47 ++++++++++++++++ dizoo/atari/example/atari_ppo_ddp.py | 54 +++++++++++++++++++ 4 files changed, 111 insertions(+) create mode 100644 dizoo/atari/example/atari_ppo.py create mode 100644 dizoo/atari/example/atari_ppo_ddp.py diff --git 
a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index 57fe866b2b..3ada84c5dd 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -42,6 +42,8 @@ def _gae(ctx: "OnlineRLContext"): data = ctx.trajectories # list data = ttorch_collate(data) with torch.no_grad(): + if cfg.policy.cuda: + data = data.cuda() value = model.forward(data.obs, mode='compute_critic')['value'] next_value = model.forward(data.next_obs, mode='compute_critic')['value'] data.value = value @@ -53,6 +55,8 @@ def _gae(ctx: "OnlineRLContext"): # done is bool type when acquired from env.step data_ = gae_data(data.value, next_value, data.reward, data.done.float(), traj_flag.float()) data.adv = gae(data_, cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda) + if cfg.policy.cuda: + data = data.cpu() if buffer_ is None: ctx.train_data = data else: diff --git a/ding/framework/middleware/functional/termination_checker.py b/ding/framework/middleware/functional/termination_checker.py index b6879c2a16..3f7cdc0cc4 100644 --- a/ding/framework/middleware/functional/termination_checker.py +++ b/ding/framework/middleware/functional/termination_checker.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Union, Callable, Optional +from ditk import logging import numpy as np import torch from ding.utils import broadcast @@ -18,8 +19,10 @@ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]): # ">" is better than ">=" when taking logger result into consideration if ctx.env_step > max_env_step: task.finish = True + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) if ctx.train_iter > max_train_iter: task.finish = True + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) return _check @@ -35,12 +38,15 @@ def _check(ctx): if rank == 0: if ctx.env_step > max_env_step: finish = torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) elif ctx.train_iter > max_train_iter: finish = torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) else: finish = torch.LongTensor([task.finish]).cuda() else: finish = torch.zeros(1).long().cuda() + # broadcast finish result to other DDP workers broadcast(finish, 0) task.finish = finish.cpu().bool().item() diff --git a/dizoo/atari/example/atari_ppo.py b/dizoo/atari/example/atari_ppo.py new file mode 100644 index 0000000000..94b99ca8c2 --- /dev/null +++ b/dizoo/atari/example/atari_ppo.py @@ -0,0 +1,47 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) 
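    # Note: the advantage_estimator change earlier in this patch makes gae_estimator (used
    # below) run its no-grad value forward on GPU when cfg.policy.cuda is set, then move the
    # batch back to CPU before it is pushed to the buffer or handed to the trainer,
    # presumably so the collated trajectories do not keep occupying GPU memory between
    # iterations. The resulting shape of that middleware, as a rough sketch:
    #
    #   with torch.no_grad():
    #       if cfg.policy.cuda:
    #           data = data.cuda()
    #       data.value = model.forward(data.obs, mode='compute_critic')['value']
    #       next_value = model.forward(data.next_obs, mode='compute_critic')['value']
    #       data.adv = gae(gae_data(...), discount_factor, gae_lambda)
    #       if cfg.policy.cuda:
    #           data = data.cpu()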
+ with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_env_step=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_ppo_ddp.py b/dizoo/atari/example/atari_ppo_ddp.py new file mode 100644 index 0000000000..eea7bea0d6 --- /dev/null +++ b/dizoo/atari/example/atari_ppo_ddp.py @@ -0,0 +1,54 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, ddp_termination_checker, online_logger +from ding.utils import set_pkg_seed, DistContext, get_rank +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.example = 'pong_ppo_seed0_ddp' + main_config.policy.learn.multi_gpu = True + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with DistContext(): + rank = get_rank() + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() From dd4e0db35cb0b7c26491eeb4a8eb0bec93a325be Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 16:55:25 +0800 Subject: [PATCH 25/70] demo(nyz): add ppo ddp avgsplit demo --- 
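Note: the "avgsplit" variant divides the global training hyper-parameters evenly across
the DDP workers, so each rank collects and optimizes over its own share while the samples
gathered per iteration and the aggregate batch per optimizer step stay roughly equal to
the single-GPU config (gradients are synchronized across ranks by the multi_gpu learner).
A minimal sketch of the convention, using only helpers that already appear in this patch:

    from ding.utils import DistContext, get_rank, get_world_size
    from dizoo.atari.config.serial.pong.pong_onppo_config import main_config

    with DistContext():
        rank, world_size = get_rank(), get_world_size()
        main_config.policy.learn.multi_gpu = True
        main_config.policy.learn.batch_size //= world_size     # per-rank batch
        main_config.policy.collect.n_sample //= world_size     # per-rank env samples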
dizoo/atari/example/atari_ppo_ddp.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dizoo/atari/example/atari_ppo_ddp.py b/dizoo/atari/example/atari_ppo_ddp.py index eea7bea0d6..e498e03394 100644 --- a/dizoo/atari/example/atari_ppo_ddp.py +++ b/dizoo/atari/example/atari_ppo_ddp.py @@ -9,19 +9,21 @@ from ding.framework.context import OnlineRLContext from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ gae_estimator, ddp_termination_checker, online_logger -from ding.utils import set_pkg_seed, DistContext, get_rank +from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config def main(): logging.getLogger().setLevel(logging.INFO) - main_config.example = 'pong_ppo_seed0_ddp' - main_config.policy.learn.multi_gpu = True - cfg = compile_config(main_config, create_cfg=create_config, auto=True) - ding_init(cfg) with DistContext(): - rank = get_rank() + rank, world_size = get_rank(), get_world_size() + main_config.example = 'pong_ppo_seed0_ddp_avgsplit' + main_config.policy.learn.multi_gpu = True + main_config.policy.learn.batch_size = main_config.policy.learn.batch_size // world_size + main_config.policy.collect.n_sample = main_config.policy.collect.n_sample // world_size + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_cfg = deepcopy(cfg.env) collector_cfg.is_train = True From b35acece6b5447c949142fe0dd0d5123e9557f7d Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 18:29:52 +0800 Subject: [PATCH 26/70] demo(nyz): add ditask + pytorch ddp demo --- ding/entry/cli_ditask.py | 2 + dizoo/atari/example/atari_dqn_dist_ddp.py | 89 +++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 dizoo/atari/example/atari_dqn_dist_ddp.py diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index 68ec836fe6..f05bd5257a 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -62,6 +62,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @click.option("-m", "--main", type=str, help="Main function of entry module.") +@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP") def cli_ditask(*args, **kwargs): return _cli_ditask(*args, **kwargs) @@ -105,6 +106,7 @@ def _cli_ditask( mq_type: str, redis_host: str, redis_port: int, + local_rank: int = 0, platform: str = None, platform_spec: str = None, ): diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py new file mode 100644 index 0000000000..0ca678a4d2 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -0,0 +1,89 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, 
termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + from ding.utils import DistContext, get_rank + with DistContext(): + rank = get_rank() + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'evaluator' in task.router.labels: + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() From 7e71cbb51d63e3a0155a4ed81a0c45dd24b40ed4 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 9 Jun 2022 20:54:49 +0800 Subject: [PATCH 27/70] fix(nyz): fix dict-type obs bugs --- ding/envs/env_manager/base_env_manager.py | 2 ++ ding/envs/env_manager/subprocess_env_manager.py | 2 ++ ding/framework/middleware/functional/collector.py | 3 ++- ding/framework/middleware/functional/evaluator.py | 4 ++-- ding/torch_utils/__init__.py | 2 +- ding/torch_utils/data_helper.py | 10 ++++++++++ ding/utils/data/collate_fn.py | 7 ++++--- 7 files changed, 23 insertions(+), 7 deletions(-) diff --git 
a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py index d27606d397..1c44c16803 100644 --- a/ding/envs/env_manager/base_env_manager.py +++ b/ding/envs/env_manager/base_env_manager.py @@ -414,6 +414,8 @@ def ready_obs(self) -> tnp.array: """ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] obs = [self._ready_obs[i] for i in active_env] + if isinstance(obs[0], dict): + obs = [tnp.array(o) for o in obs] return tnp.stack(obs) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index f1417c66e9..4aa13baa9f 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -914,6 +914,8 @@ def ready_obs(self) -> tnp.array: time.sleep(0.001) sleep_count += 1 obs = [self._ready_obs[i] for i in self.ready_env] + if isinstance(obs[0], dict): + obs = [tnp.array(o) for o in obs] return tnp.stack(obs) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index 0f6f7f6e7d..4d99302903 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -6,6 +6,7 @@ from ding.envs import BaseEnvManager from ding.policy import Policy from ding.framework import task +from ding.torch_utils import get_shape0 if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -75,7 +76,7 @@ def _inference(ctx: "OnlineRLContext"): ctx.obs = obs # TODO mask necessary rollout - obs = {i: obs[i] for i in range(obs.shape[0])} # TBD + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs, **ctx.collect_kwargs) ctx.action = [v['action'].numpy() for v in inference_output.values()] # TBD ctx.inference_output = inference_output diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index f06f1b8602..2d80a3aca0 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -10,7 +10,7 @@ from ding.policy import Policy from ding.data import Dataset, DataLoader from ding.framework import task -from ding.torch_utils import tensor_to_list +from ding.torch_utils import tensor_to_list, get_shape0 from ding.utils import lists_to_dicts from ding.framework import Context, OnlineRLContext, OfflineRLContext @@ -182,7 +182,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): while not eval_monitor.is_finished(): obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32) - obs = {i: obs[i] for i in range(obs.shape[0])} # TBD + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs) action = [v['action'].numpy() for v in inference_output.values()] # TBD timesteps = env.step(action) diff --git a/ding/torch_utils/__init__.py b/ding/torch_utils/__init__.py index 250c468233..717406e12f 100644 --- a/ding/torch_utils/__init__.py +++ b/ding/torch_utils/__init__.py @@ -1,6 +1,6 @@ from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \ - build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data + build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data, 
get_shape0 from .distribution import CategoricalPd, CategoricalPdPytorch from .loss import * from .metric import levenshtein_distance, hamming_distance diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 9e0b8e7861..51c4d51664 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -8,6 +8,7 @@ import numpy as np import torch +import treetensor.torch as ttorch def to_device(item: Any, device: str, ignore_keys: list = []) -> Any: @@ -397,3 +398,12 @@ def get_null_data(template: Any, num: int) -> List[Any]: data['reward'].zero_() ret.append(data) return ret + + +def get_shape0(data): + if isinstance(data, torch.Tensor): + return data.shape[0] + elif isinstance(data, ttorch.Tensor): + return list(data.shape.values())[0][0] + else: + raise TypeError("not support type: {}".format(data)) diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 0b7fcda9cb..98ed173886 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -56,11 +56,12 @@ def default_collate(batch: Sequence, - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with batch size into each data field.\ the return dtype depends on the original element dtype, can be [torch.Tensor, Mapping, Sequence]. """ - elem = batch[0] - elem_type = type(elem) if isinstance(batch, ttorch.Tensor): return batch.json() + + elem = batch[0] + elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch_gt_131() and torch.utils.data.get_worker_info() is not None: @@ -78,7 +79,7 @@ def default_collate(batch: Sequence, elif isinstance(elem, ttorch.Tensor): ret = ttorch.stack(batch).json() for k in ret: - if len(ret[k].shape) == 2 and ret[k].shape[1] == 1: # reshape (B, 1) -> (B) + if hasattr(ret[k], 'shape') and len(ret[k].shape) >= 2 and ret[k].shape[1] == 1: # reshape (B, 1) -> (B) ret[k] = ret[k].squeeze(1) return ret elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ From 629a5acfbc257c684b9e3bfce93b2a98a32b75d7 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Fri, 10 Jun 2022 15:58:02 +0800 Subject: [PATCH 28/70] fix(nyz): fix get_shape0 bug when nested structure --- ding/torch_utils/data_helper.py | 10 +++++++++- ding/torch_utils/tests/test_data_helper.py | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 51c4d51664..d49bb6f957 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -404,6 +404,14 @@ def get_shape0(data): if isinstance(data, torch.Tensor): return data.shape[0] elif isinstance(data, ttorch.Tensor): - return list(data.shape.values())[0][0] + + def fn(t): + item = list(t.values())[0] + if np.isscalar(item[0]): + return item[0] + else: + return fn(item) + + return fn(data.shape) else: raise TypeError("not support type: {}".format(data)) diff --git a/ding/torch_utils/tests/test_data_helper.py b/ding/torch_utils/tests/test_data_helper.py index 218ce59ba7..629d081b0a 100644 --- a/ding/torch_utils/tests/test_data_helper.py +++ b/ding/torch_utils/tests/test_data_helper.py @@ -4,9 +4,10 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader +import treetensor.torch as ttorch from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \ - tensor_to_list, same_shape, build_log_buffer, get_tensor_data + tensor_to_list, same_shape, build_log_buffer, get_tensor_data, get_shape0 from ding.utils 
import EasyTimer @@ -132,6 +133,18 @@ def test_get_tensor_data(self): with pytest.raises(TypeError): get_tensor_data(EasyTimer()) + def test_get_shape0(self): + a = { + 'a': { + 'b': torch.randn(4, 3) + }, + 'c': { + 'd': torch.randn(4) + }, + } + a = ttorch.as_tensor(a) + assert get_shape0(a) == 4 + @pytest.mark.unittest def test_log_dict(): From 88c79640a5f7882aeb3e65cc2181f4fb939035ac Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Thu, 9 Jun 2022 15:20:14 +0800 Subject: [PATCH 29/70] Route finish event to all processes in the cluster --- ding/framework/task.py | 4 +-- ding/framework/tests/test_parallel.py | 2 -- ding/framework/tests/test_task.py | 50 +++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/ding/framework/task.py b/ding/framework/task.py index 53e95716b0..d67c0558b4 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -330,6 +330,8 @@ def stop(self) -> None: Overview: Stop and cleanup every thing in the runtime of task. """ + if self.router.is_active: + self.emit("finish", True) if self._thread_pool: self._thread_pool.shutdown() self._event_loop.stop() @@ -472,8 +474,6 @@ def finish(self): @finish.setter def finish(self, value: bool): self._finish = value - if self.router.is_active and value is True: - self.emit("finish", value) def _wrap_event_name(self, event: str) -> str: """ diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py index b042cb3a57..3c7c190f0c 100644 --- a/ding/framework/tests/test_parallel.py +++ b/ding/framework/tests/test_parallel.py @@ -1,9 +1,7 @@ from collections import defaultdict import pytest import time -import os from ding.framework import Parallel -from ding.utils.design_helper import SingletonMetaclass def parallel_main(): diff --git a/ding/framework/tests/test_task.py b/ding/framework/tests/test_task.py index c9d1243b6e..36a80d23f0 100644 --- a/ding/framework/tests/test_task.py +++ b/ding/framework/tests/test_task.py @@ -1,6 +1,7 @@ +import multiprocessing as mp import pytest -from threading import Lock -from time import sleep +from threading import Lock, Thread +from time import sleep, time import random from ding.framework import task, Context, Parallel @@ -331,3 +332,48 @@ def slowest(ctx): task.use(fast, lock=lock) task.run(1) assert task.ctx.result == "slowest" + + +def broadcast_finish_main(): + with task.start(): + + def tick(ctx: Context): + if task.router.node_id == 1 and ctx.total_step == 1: + task.finish = True + sleep(1) + + task.use(tick) + task.run(20) + + +def broadcast_main_target(): + Parallel.runner( + n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555 + )(broadcast_finish_main) + + +def broadcast_secondary_target(): + "Start two standalone processes and connect to the main process." 
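# ---------------------------------------------------------------------------
# Editor's note, not part of the patch: a minimal sketch of how a peer
# process could react to the "finish" event that Task.stop() now broadcasts
# (see the change to ding/framework/task.py above). It only uses task.on and
# task.finish as they appear elsewhere in this PR; the handler name
# _on_remote_finish is hypothetical.
from ding.framework import task

def _on_remote_finish(value: bool) -> None:
    # Mark the local task as finished so its middleware loop exits after the
    # current iteration. The finish setter no longer re-emits the event, so
    # this cannot echo the broadcast back and loop.
    task.finish = value

task.on("finish", _on_remote_finish)
# ---------------------------------------------------------------------------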
+ Parallel.runner( + n_parallel_workers=2, + protocol="tcp", + address="127.0.0.1", + topology="alone", + ports=50556, + attach_to=["tcp://127.0.0.1:50555"], + node_ids=[1, 2] + )(broadcast_finish_main) + + +@pytest.mark.unittest +@pytest.mark.timeout(10) +def test_broadcast_finish(): + start = time() + ctx = mp.get_context("spawn") + main_process = ctx.Process(target=broadcast_main_target) + secondary_process = ctx.Process(target=broadcast_secondary_target) + main_process.start() + secondary_process.start() + main_process.join() + secondary_process.join() + assert (time() - start) < 10 From 6ef568f8db393faaf169346843b245d64df1e338 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sat, 14 May 2022 22:57:35 +0800 Subject: [PATCH 30/70] demo(nyz): add naive dp demo --- dizoo/atari/example/atari_dqn_dp.py | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 dizoo/atari/example/atari_dqn_dp.py diff --git a/dizoo/atari/example/atari_dqn_dp.py b/dizoo/atari/example/atari_dqn_dp.py new file mode 100644 index 0000000000..b796c70ad6 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dp.py @@ -0,0 +1,53 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.torch_utils import DataParallel +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_dp' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + model = DataParallel(model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_train_iter=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() From 60d09279f7b41bb680da6c7ff5e20c55c25fe063 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 15 May 2022 19:47:37 +0800 Subject: [PATCH 31/70] demo(nyz): add naive ddp demo --- ding/utils/pytorch_ddp_dist_helper.py | 4 +- dizoo/atari/example/atari_dqn_ddp.py | 
83 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 dizoo/atari/example/atari_dqn_ddp.py diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 96847be357..3c9d5881fb 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -114,7 +114,9 @@ def dist_finalize() -> None: Overview: Finalize distributed training resources """ - dist.destroy_process_group() + # This operation usually hangs out so we ignore it temporally. + # dist.destroy_process_group() + pass class DistContext: diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py new file mode 100644 index 0000000000..d4445bcb66 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -0,0 +1,83 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.utils import DistContext, get_rank +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): + import numpy as np + import torch + from ding.utils import broadcast + if rank == 0: + if max_env_step is None: + max_env_step = np.inf + if max_train_iter is None: + max_train_iter = np.inf + + def _check(ctx): + if rank == 0: + if ctx.env_step > max_env_step: + finish = torch.ones(1).long().cuda() + elif ctx.train_iter > max_train_iter: + finish = torch.ones(1).long().cuda() + else: + finish = torch.zeros(1).long().cuda() + else: + finish = torch.zeros(1).long().cuda() + broadcast(finish, 0) + task.finish = finish.cpu().bool().item() + + return _check + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ddp' + main_config.policy.learn.multi_gpu = True + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with DistContext(): + rank = get_rank() + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, 
train_freq=1000)) + task.use(ddp_termination_checker(max_train_iter=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() From 1011fee95511a3f0adbadb9278da6f1963ec9d9f Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 15 May 2022 21:58:49 +0800 Subject: [PATCH 32/70] feature(nyz): add naive tb_logger in new evaluator --- ding/framework/middleware/functional/evaluator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 4d864b3bae..e19561211d 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,6 +156,8 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> """ env.seed(cfg.seed, dynamic_seed=False) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(cfg.exp_name) def _evaluate(ctx: "OnlineRLContext"): """ @@ -194,6 +196,8 @@ def _evaluate(ctx: "OnlineRLContext"): eval_monitor.update_reward(env_id, reward) episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) + tb_logger.add_scalar('basic/eval_episode_reward_mean-env_step', eval_reward, ctx.env_step) + tb_logger.add_scalar('basic/eval_episode_reward_mean-train_iter', eval_reward, ctx.train_iter) stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0 logging.info( 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( From b45ec415588ce73157a6583351f43c76eb5dfe57 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 17 May 2022 10:55:19 +0800 Subject: [PATCH 33/70] feature(nyz): add soft update in DQN target network --- ding/policy/dqn.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 6198c6c9a5..fe2c97f889 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -144,12 +144,22 @@ def _init_learn(self) -> None: # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.learn.target_update_freq} - ) + if 'target_update_freq' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.learn.target_update_freq} + ) + elif 'target_theta' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.learn.target_theta} + ) + else: + raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta") self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') self._learn_model.reset() self._target_model.reset() @@ -190,7 +200,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: # Target q value with torch.no_grad(): target_q_value = self._target_model.forward(data['next_obs'])['logit'] - # Max q value action (main model) + # Max q value action (main model), i.e. 
Double DQN target_q_action = self._learn_model.forward(data['next_obs'])['action'] data_n = q_nstep_td_data( From f9240c7cded6ea4da2f3c486a8725f150919ba89 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 17 May 2022 11:03:00 +0800 Subject: [PATCH 34/70] fix(nyz): fix termination env_step bug and eval task.finish broadcast bug --- dizoo/atari/example/atari_dqn.py | 2 +- dizoo/atari/example/atari_dqn_ddp.py | 4 ++-- dizoo/atari/example/atari_dqn_dp.py | 2 +- dizoo/mujoco/example/mujoco_sac.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dizoo/atari/example/atari_dqn.py b/dizoo/atari/example/atari_dqn.py index 4f4be49ecb..1ac9fdc0c8 100644 --- a/dizoo/atari/example/atari_dqn.py +++ b/dizoo/atari/example/atari_dqn.py @@ -42,7 +42,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(termination_checker(max_train_iter=int(1e7))) + task.use(termination_checker(max_env_step=int(1e7))) task.run() diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py index d4445bcb66..bc07da3bfb 100644 --- a/dizoo/atari/example/atari_dqn_ddp.py +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -32,7 +32,7 @@ def _check(ctx): elif ctx.train_iter > max_train_iter: finish = torch.ones(1).long().cuda() else: - finish = torch.zeros(1).long().cuda() + finish = torch.LongTensor([task.finish]).cuda() else: finish = torch.zeros(1).long().cuda() broadcast(finish, 0) @@ -75,7 +75,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) if rank == 0: task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(ddp_termination_checker(max_train_iter=int(1e7), rank=rank)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) task.run() diff --git a/dizoo/atari/example/atari_dqn_dp.py b/dizoo/atari/example/atari_dqn_dp.py index b796c70ad6..cea9618061 100644 --- a/dizoo/atari/example/atari_dqn_dp.py +++ b/dizoo/atari/example/atari_dqn_dp.py @@ -45,7 +45,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(termination_checker(max_train_iter=int(1e7))) + task.use(termination_checker(max_env_step=int(1e7))) task.run() diff --git a/dizoo/mujoco/example/mujoco_sac.py b/dizoo/mujoco/example/mujoco_sac.py index 0349d8d161..471e4c8f29 100644 --- a/dizoo/mujoco/example/mujoco_sac.py +++ b/dizoo/mujoco/example/mujoco_sac.py @@ -37,7 +37,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=500)) - task.use(termination_checker(max_train_iter=int(3e6))) + task.use(termination_checker(max_env_step=int(3e6))) task.run() From 96e376ed6a68d2a816d85c5fbffe0536abca6139 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Mon, 16 May 2022 12:17:18 +0800 Subject: [PATCH 35/70] Add singleton log writer --- ding/utils/__init__.py | 2 +- ding/utils/log_writer_helper.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 41cc906fb9..e4178b5337 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -13,7 +13,7 @@ K8sLauncher from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock from .log_helper import build_logger, pretty_print, LoggerFactory -from .log_writer_helper import DistributedWriter +from .log_writer_helper 
import DistributedWriter, distributed_writer from .orchestrator_launcher import OrchestratorLauncher from .profiler_helper import Profiler, register_profiler from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \ diff --git a/ding/utils/log_writer_helper.py b/ding/utils/log_writer_helper.py index 4795fdaf71..4a6b36929f 100644 --- a/ding/utils/log_writer_helper.py +++ b/ding/utils/log_writer_helper.py @@ -104,3 +104,8 @@ def _parallel_fn(self: DistributedWriter, *args, **kwargs): for fn_name in ready_to_parallel_fns: if hasattr(DistributedWriter, fn_name): setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name))) + +# Examples: +# In main, `distributed_writer.plugin(task.router, is_writer=True)`, +# In middleware, `distributed_writer.record()` +distributed_writer = DistributedWriter() From ec413df7266c8f91c2878e1f7b60e94004d00e0b Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Mon, 16 May 2022 18:25:15 +0800 Subject: [PATCH 36/70] Use get_instance on writer --- ding/utils/__init__.py | 2 +- ding/utils/log_writer_helper.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index e4178b5337..41cc906fb9 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -13,7 +13,7 @@ K8sLauncher from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock from .log_helper import build_logger, pretty_print, LoggerFactory -from .log_writer_helper import DistributedWriter, distributed_writer +from .log_writer_helper import DistributedWriter from .orchestrator_launcher import OrchestratorLauncher from .profiler_helper import Profiler, register_profiler from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \ diff --git a/ding/utils/log_writer_helper.py b/ding/utils/log_writer_helper.py index 4a6b36929f..7efbc32416 100644 --- a/ding/utils/log_writer_helper.py +++ b/ding/utils/log_writer_helper.py @@ -17,6 +17,7 @@ class DistributedWriter(SummaryWriter): The best way is to use it in conjunction with the ``router`` to take advantage of the message \ and event components of the router (see ``writer.plugin``). """ + root = None def __init__(self, *args, **kwargs): self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True @@ -30,6 +31,21 @@ def __init__(self, *args, **kwargs): self._is_writer = False self._lazy_initialized = False + @classmethod + def get_instance(cls, *args, **kwargs) -> "DistributedWriter": + """ + Overview: + Get instance and set the root level instance on the first called. If args and kwargs is none, + this method will return root instance. 
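# ---------------------------------------------------------------------------
# Editor's illustration, not part of the patch: how the root instance is used
# elsewhere in this PR. ding_init creates the writer once with the experiment
# name, and middlewares later fetch the same writer by calling get_instance()
# with no arguments. The experiment name 'my_exp' is a placeholder.
from ding.utils import DistributedWriter

root = DistributedWriter.get_instance('my_exp')  # first call sets the root
same = DistributedWriter.get_instance()          # later calls return it
assert root is same
# ---------------------------------------------------------------------------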
+ """ + if args or kwargs: + ins = cls(*args, **kwargs) + if cls.root is None: + cls.root = ins + return ins + else: + return cls.root + def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter": """ Overview: From 0bb1d779a2034735897301abb3c2421186708c18 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 16 May 2022 20:31:16 +0800 Subject: [PATCH 37/70] feature(nyz): add general logger middleware --- ding/example/sac.py | 6 ++- ding/framework/__init__.py | 6 +++ ding/framework/context.py | 6 ++- ding/framework/middleware/ckpt_handler.py | 2 +- .../middleware/functional/__init__.py | 1 + .../middleware/functional/evaluator.py | 4 -- .../framework/middleware/functional/logger.py | 48 +++++++++++++++++++ ding/policy/sac.py | 5 +- 8 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 ding/framework/middleware/functional/logger.py diff --git a/ding/example/sac.py b/ding/example/sac.py index b47352c609..863479f964 100644 --- a/ding/example/sac.py +++ b/ding/example/sac.py @@ -5,10 +5,10 @@ from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import DequeBuffer from ding.config import compile_config -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \ - CkptSaver, OffPolicyLearner, termination_checker + CkptSaver, OffPolicyLearner, termination_checker, online_logger from ding.utils import set_pkg_seed from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config @@ -17,6 +17,7 @@ def main(): logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2(env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(10)], cfg=cfg.env.manager) evaluator_env = BaseEnvManagerV2(env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(5)], cfg=cfg.env.manager) @@ -35,6 +36,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=100)) task.use(termination_checker(max_train_iter=10000)) + task.use(online_logger()) task.run() diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index 4a19f56316..274d6f2364 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -3,3 +3,9 @@ from .parallel import Parallel from .event_loop import EventLoop from .supervisor import Supervisor +from easydict import EasyDict +from ding.utils import DistributedWriter + + +def ding_init(cfg: EasyDict): + DistributedWriter.get_instance(cfg.exp_name) diff --git a/ding/framework/context.py b/ding/framework/context.py index 949e0e7d01..70c1143d0d 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -48,13 +48,14 @@ def __init__(self, *args, **kwargs) -> None: self.env_episode = 0 self.train_iter = 0 self.train_data = None + self.train_output = None # collect self.collect_kwargs = {} self.trajectories = None self.episodes = None self.trajectory_end_idx = [] # eval - self.eval_value = -np.inf + self.eval_value = None self.last_eval_iter = -1 self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter') @@ -70,8 +71,9 @@ def __init__(self, *args, **kwargs) -> None: self.train_epoch = 0 self.train_iter = 0 self.train_data = None + 
self.train_output = None # eval - self.eval_value = -np.inf + self.eval_value = None self.last_eval_iter = -1 self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 89772e66ee..4657eff8d9 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -54,7 +54,7 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.last_save_iter = ctx.train_iter # best eval reward so far - if ctx.eval_value > self.max_eval_value: + if ctx.eval_value and ctx.eval_value > self.max_eval_value: save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) self.max_eval_value = ctx.eval_value diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index defdc3f4c4..463ee84284 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -5,6 +5,7 @@ from .evaluator import interaction_evaluator from .termination_checker import termination_checker from .pace_controller import pace_controller +from .logger import online_logger # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index e19561211d..4d864b3bae 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,8 +156,6 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> """ env.seed(cfg.seed, dynamic_seed=False) - from tensorboardX import SummaryWriter - tb_logger = SummaryWriter(cfg.exp_name) def _evaluate(ctx: "OnlineRLContext"): """ @@ -196,8 +194,6 @@ def _evaluate(ctx: "OnlineRLContext"): eval_monitor.update_reward(env_id, reward) episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) - tb_logger.add_scalar('basic/eval_episode_reward_mean-env_step', eval_reward, ctx.env_step) - tb_logger.add_scalar('basic/eval_episode_reward_mean-train_iter', eval_reward, ctx.train_iter) stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0 logging.info( 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py new file mode 100644 index 0000000000..bec65cbd1f --- /dev/null +++ b/ding/framework/middleware/functional/logger.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, Callable, Dict, List +from collections import deque +from ding.utils import DistributedWriter + +if TYPE_CHECKING: + from ding.framework import OnlineRLContext + + +def online_logger(record_train_iter: bool = False) -> Callable: + writer = DistributedWriter.get_instance() + + def _logger(ctx: "OnlineRLContext"): + if ctx.eval_value is not None: + if record_train_iter: + writer.add_scalar('basic/eval_episode_reward_mean-env_step', ctx.eval_value, ctx.env_step) + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + else: + writer.add_scalar('basic/eval_episode_reward_mean', ctx.eval_value, ctx.env_step) + if ctx.train_output is not None: + if isinstance(ctx.train_output, deque): + output = ctx.train_output.pop() # only use latest output + else: + output = ctx.train_output + # TODO(nyz) ppo train log case + if isinstance(output, List): + raise 
NotImplementedError + for k, v in output.items(): + if k in ['priority']: + continue + if "[scalars]" in k: + new_k = k.split(']')[-1] + raise NotImplementedError + elif "[histogram]" in k: + new_k = k.split(']')[-1] + writer.add_histogram(new_k, v, ctx.env_step) + if record_train_iter: + writer.add_histogram(new_k, v, ctx.train_iter) + else: + if record_train_iter: + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step) + else: + writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step) + + return _logger + + +# TODO offline logger diff --git a/ding/policy/sac.py b/ding/policy/sac.py index 2b12b27873..bd6102a565 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -361,8 +361,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: torch.zeros_like(self._alpha)).requires_grad_() loss_dict['total_loss'] = sum(loss_dict.values()) - info_dict = {} - # ============= # after update # ============= @@ -379,8 +377,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'q_value_2': target_q_value[1].detach().mean().item(), 'target_value': target_value.detach().mean().item(), 'entropy': entropy.item(), - **info_dict, - **loss_dict + #**loss_dict } def _state_dict_learn(self) -> Dict[str, Any]: From c86fb2e1ce46a8211f296ddcd8c914effd6e25ab Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 15:52:02 +0800 Subject: [PATCH 38/70] Support distributed dqn --- ding/example/dqn_dist.py | 97 +++++++++++++++++++ ding/framework/context.py | 5 +- ding/framework/middleware/ckpt_handler.py | 6 +- .../middleware/functional/__init__.py | 1 + .../middleware/functional/evaluator.py | 5 +- .../middleware/functional/exchanger.py | 73 ++++++++++++++ ding/framework/parallel.py | 5 +- ding/utils/data/structure/__init__.py | 1 + ding/utils/data/structure/lifo_deque.py | 12 +++ 9 files changed, 198 insertions(+), 7 deletions(-) create mode 100644 ding/example/dqn_dist.py create mode 100644 ding/framework/middleware/functional/exchanger.py create mode 100644 ding/utils/data/structure/lifo_deque.py diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py new file mode 100644 index 0000000000..72516a62c6 --- /dev/null +++ b/ding/example/dqn_dist.py @@ -0,0 +1,97 @@ +""" +The distributed version of DQN pipeline. +With N workers = 1 learner + 1 evaluator + (N-2) actors + +# First Example —— Execute on one machine with multi processes. +Execute 4 processes with 1 learner + 1 evaluator + 2 actors +Remember to keep them connected by mesh to ensure that they can exchange information with each other. + +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 4 --topology mesh + +# Second Example —— Execute on multiple machines. +1. Execute 1 learner + 1 evaluator on one machine. +> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 + +2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). + Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. + Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. + And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. +> ditask --package . 
--main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ + --ports 50517 --attach-to tcp://127.0.0.1:50515,tcp://127.0.0.1:50516 +""" +import gym +import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger +from ding.utils import set_pkg_seed +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + # cfg.env.stop_value = 99999999 # Don't stop + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if task.router.node_id == 0: # Learner + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=100)) + + elif task.router.node_id == 1: # Evaluator + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + + else: # Collectors + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/framework/context.py b/ding/framework/context.py index 70c1143d0d..3ad27a3b83 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -7,6 +7,7 @@ class Context(dict): Context is an object that pass contextual data between middlewares, whose life cycle is only one training iteration. It is a dict that reflect itself, so you can set any properties as you wish. + Note that the initial value of the property must be equal to False. 
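# ---------------------------------------------------------------------------
# Editor's sketch, not part of the patch: what the note above means in
# practice. Every attribute starts from a falsy value (None, 0, [], {}, ...),
# and only attributes registered via keep() survive across iterations; the
# rest are re-initialized at each training iteration. MyContext is a
# hypothetical subclass for illustration only.
class MyContext(Context):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.train_iter = 0       # kept across iterations
        self.train_output = None  # reset every iteration
        self.keep('train_iter')
# ---------------------------------------------------------------------------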
""" def __init__(self, *args, **kwargs) -> None: @@ -56,7 +57,7 @@ def __init__(self, *args, **kwargs) -> None: self.trajectory_end_idx = [] # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter') @@ -74,6 +75,6 @@ def __init__(self, *args, **kwargs) -> None: self.train_output = None # eval self.eval_value = None - self.last_eval_iter = -1 + self.last_eval_iter = None self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 4657eff8d9..382886e26d 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -17,7 +17,7 @@ class CkptSaver: The class used to save checkpoint data. """ - def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None): + def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None, save_finish: bool = True): """ Overview: Initialize the `CkptSaver`. @@ -25,6 +25,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.exp_name`. - policy (:obj:`Policy`): Policy used to save the checkpoint. - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data. + - save_finish (:obj:`int`): Whether save final ckpt when ``task.finish = True``. """ self.policy = policy self.train_freq = train_freq @@ -33,6 +34,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No os.mkdir(self.prefix) self.last_save_iter = 0 self.max_eval_value = -np.inf + self.save_finish = save_finish def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: """ @@ -59,5 +61,5 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.max_eval_value = ctx.eval_value # finish - if task.finish: + if task.finish and self.save_finish: save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 463ee84284..743c61f8e3 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -6,6 +6,7 @@ from .termination_checker import termination_checker from .pace_controller import pace_controller from .logger import online_logger +from .exchanger import context_exchanger, model_exchanger # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 4d864b3bae..b1c9599953 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -169,7 +169,8 @@ def _evaluate(ctx: "OnlineRLContext"): - eval_value (:obj:`float`): The average reward in the current evaluation. 
""" - if ctx.last_eval_iter != -1 and \ + # evaluation will be executed if the task begins or enough train_iter after last evaluation + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return @@ -214,7 +215,7 @@ def metric_evaluator(cfg: EasyDict, policy: Policy, dataset: Dataset, metric: IM def _evaluate(ctx: "Context"): # evaluation will be executed if the task begins or enough train_iter after last evaluation - if ctx.last_eval_iter != -1 and \ + if ctx.last_eval_iter is not None and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return diff --git a/ding/framework/middleware/functional/exchanger.py b/ding/framework/middleware/functional/exchanger.py new file mode 100644 index 0000000000..8710855f40 --- /dev/null +++ b/ding/framework/middleware/functional/exchanger.py @@ -0,0 +1,73 @@ +from time import sleep +from typing import TYPE_CHECKING, List, Dict +from ding.framework import task +from ding.utils.data.structure.lifo_deque import LifoDeque +if TYPE_CHECKING: + from ding.framework.context import Context + from torch.nn import Module + + +def context_exchanger(send_keys: List[str] = None, recv_keys: List[str] = None, skip_n_iter: int = 0): + """ + Overview: + Send data from context in the backward stage. + Buffer received data and wait if not get any data. + Arguments: + - send_keys (:obj:`List[str]`): Keys need to be sent. + - recv_keys (:obj:`List[str]`): Keys need to be received. + - skip_n_iter (:obj:`int`): Whether to skip the first N round of waiting, + e.g. collecting data without waiting for a new model in the first N round, + while training a model that needs to wait for data in the first round. + """ + event_name = "context_exchanger" + + bufferd_payloads = LifoDeque(maxsize=100) + task.on(event_name, lambda payload: bufferd_payloads.put(payload)) + + def _context_exchanger(ctx: "Context"): + if recv_keys: + if ctx.total_step >= skip_n_iter: + payload: Dict = bufferd_payloads.get() + for key in recv_keys: + value = payload.get(key) + if value: + ctx[key] = value + + if send_keys: + yield + payload = {} + for key in send_keys: + payload[key] = ctx.get(key) + if payload: + task.emit(event_name, payload, only_remote=True) + + return _context_exchanger + + +def model_exchanger(model: "Module", is_learner: bool = False): + """ + Overview: + Exchange model between processes, only the learner will send the model, + otherwise the model will only be received. + If you are using a shared model on a single host, there is no need to use this middleware. + Arguments: + - model (:obj:`torch.nn.Module`): Pytorch module. + - is_learner (:obj:`bool`): Whether use this middleware as learner or not. 
+ """ + event_name = "model_exchanger" + bufferd_state_dict = LifoDeque(maxsize=1) + + if not is_learner: + task.on(event_name, lambda state_dict: bufferd_state_dict.put(state_dict)) + + def _model_exchanger(ctx: "Context"): + if not is_learner: + if ctx.total_step != 0: # Skip first iteration + state_dict = bufferd_state_dict.get() + model.load_state_dict(state_dict) + + if is_learner: + yield + task.emit(event_name, model.state_dict(), only_remote=True) + + return _model_exchanger diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 469ae7e77f..6db8808600 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -34,6 +34,7 @@ def __init__(self) -> None: def _run( self, node_id: int, + n_parallel_workers: int, labels: Optional[Set[str]] = None, auto_recover: bool = False, max_retries: int = float("inf"), @@ -41,6 +42,7 @@ def _run( **kwargs ) -> None: self.node_id = node_id + self.n_parallel_workers = n_parallel_workers self.labels = labels or set() self.auto_recover = auto_recover self.max_retries = max_retries @@ -156,6 +158,7 @@ def topology_network(i: int) -> List[str]: "node_id": candidate_node_ids[i], "listen_to": nodes[i], "attach_to": topology_network(i), + "n_parallel_workers": n_parallel_workers, } runner_params.append(runner_kwargs) @@ -166,7 +169,7 @@ def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[Li runner_params = [] candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0) for i in range(n_parallel_workers): - runner_kwargs = {**kwargs, "node_id": candidate_node_ids[i]} + runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]} runner_params.append(runner_kwargs) return runner_params diff --git a/ding/utils/data/structure/__init__.py b/ding/utils/data/structure/__init__.py index 9e8011f9d4..3cc58828a6 100644 --- a/ding/utils/data/structure/__init__.py +++ b/ding/utils/data/structure/__init__.py @@ -1 +1,2 @@ from .cache import Cache +from .lifo_deque import LifoDeque diff --git a/ding/utils/data/structure/lifo_deque.py b/ding/utils/data/structure/lifo_deque.py new file mode 100644 index 0000000000..00d9221e5c --- /dev/null +++ b/ding/utils/data/structure/lifo_deque.py @@ -0,0 +1,12 @@ +from queue import LifoQueue +from collections import deque + + +class LifoDeque(LifoQueue): + """ + Like LifoQueue, but automatically replaces the oldest data when the queue is full. + """ + + def _init(self, maxsize): + self.maxsize = maxsize + 1 + self.queue = deque(maxlen=maxsize) From 3da345560a127e3a7ef32c3dae1e0403964be330 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 16:06:56 +0800 Subject: [PATCH 39/70] Add more desc (ci skip) --- ding/example/dqn_dist.py | 10 ++++++++-- ding/framework/message_queue/nng.py | 1 + ding/framework/parallel.py | 1 - 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py index 72516a62c6..8d08d3be6d 100644 --- a/ding/example/dqn_dist.py +++ b/ding/example/dqn_dist.py @@ -10,14 +10,20 @@ # Second Example —— Execute on multiple machines. 1. Execute 1 learner + 1 evaluator on one machine. + > ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. 
Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. + The value of the `attach_to` parameter should be obtained from the log of the + process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). + > ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ - --ports 50517 --attach-to tcp://127.0.0.1:50515,tcp://127.0.0.1:50516 + --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 + +3. You can repeat step 2 to start more collectors on other machines. """ import gym import logging @@ -68,7 +74,7 @@ def main(): env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) - task.use(context_exchanger(recv_keys=["train_iter"], skip_n_iter=1)) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) task.use(model_exchanger(model, is_learner=False)) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(CkptSaver(cfg, policy, save_finish=False)) diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index feab6473a1..fd48c56585 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -30,6 +30,7 @@ def listen(self) -> None: sleep(0.1) # Wait for peers to bind for contact in self.attach_to: sock.dial(contact) + logging.info("NNG listen on {}, attach to {}".format(self.listen_to, self.attach_to)) def publish(self, topic: str, data: bytes) -> None: if not self._finished: diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 6db8808600..b61149debf 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -130,7 +130,6 @@ def _nng_args_parser( ) -> Dict[str, dict]: attach_to = attach_to or [] nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports) - logging.info("Bind subprocesses on these addresses: {}".format(nodes)) def cleanup_nodes(): for node in nodes: From 205075917c597afdd4456aa87bd6387cacf96da0 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Tue, 26 Apr 2022 15:52:02 +0800 Subject: [PATCH 40/70] Support distributed dqn Add more desc (ci skip) Add timeout on model exchanger --- ding/framework/context.py | 2 ++ ding/framework/middleware/ckpt_handler.py | 2 +- ding/framework/middleware/functional/exchanger.py | 10 +++++++--- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ding/framework/context.py b/ding/framework/context.py index 3ad27a3b83..6b12ef9e70 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -55,6 +55,8 @@ def __init__(self, *args, **kwargs) -> None: self.trajectories = None self.episodes = None self.trajectory_end_idx = [] + self.action = [] + self.inference_output = None # eval self.eval_value = None self.last_eval_iter = None diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 382886e26d..b07331d719 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -56,7 +56,7 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.last_save_iter = ctx.train_iter # best eval reward so far - if ctx.eval_value and ctx.eval_value > self.max_eval_value: + if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value: save_file("{}/eval.pth.tar".format(self.prefix), 
self.policy.learn_mode.state_dict()) self.max_eval_value = ctx.eval_value diff --git a/ding/framework/middleware/functional/exchanger.py b/ding/framework/middleware/functional/exchanger.py index 8710855f40..d22578c13a 100644 --- a/ding/framework/middleware/functional/exchanger.py +++ b/ding/framework/middleware/functional/exchanger.py @@ -1,4 +1,5 @@ -from time import sleep +import logging +from queue import Empty from typing import TYPE_CHECKING, List, Dict from ding.framework import task from ding.utils.data.structure.lifo_deque import LifoDeque @@ -63,8 +64,11 @@ def model_exchanger(model: "Module", is_learner: bool = False): def _model_exchanger(ctx: "Context"): if not is_learner: if ctx.total_step != 0: # Skip first iteration - state_dict = bufferd_state_dict.get() - model.load_state_dict(state_dict) + try: + state_dict = bufferd_state_dict.get(timeout=5) + model.load_state_dict(state_dict) + except Empty: + logging.warning("Timeout when waiting for new model!") if is_learner: yield From 09c99a95dbc4148c0101c6f4e0841c10347c1163 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 26 May 2022 11:39:43 +0800 Subject: [PATCH 41/70] feature(nyz): add online logger freq --- ding/example/dqn_dist.py | 2 +- ding/framework/middleware/functional/logger.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py index 8d08d3be6d..0fd1e09111 100644 --- a/ding/example/dqn_dist.py +++ b/ding/example/dqn_dist.py @@ -26,7 +26,7 @@ 3. You can repeat step 2 to start more collectors on other machines. """ import gym -import logging +from ditk import logging from ding.model import DQN from ding.policy import DQNPolicy from ding.envs import DingEnvWrapper, BaseEnvManagerV2 diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index bec65cbd1f..915364c9c1 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -6,17 +6,20 @@ from ding.framework import OnlineRLContext -def online_logger(record_train_iter: bool = False) -> Callable: +def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: writer = DistributedWriter.get_instance() + last_train_show_iter = -1 def _logger(ctx: "OnlineRLContext"): + nonlocal last_train_show_iter if ctx.eval_value is not None: if record_train_iter: writer.add_scalar('basic/eval_episode_reward_mean-env_step', ctx.eval_value, ctx.env_step) writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) else: writer.add_scalar('basic/eval_episode_reward_mean', ctx.eval_value, ctx.env_step) - if ctx.train_output is not None: + if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: + last_train_show_iter = ctx.train_iter if isinstance(ctx.train_output, deque): output = ctx.train_output.pop() # only use latest output else: From 9c414004c1ea640697fa35907a15c9be8e404a43 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 1 Jun 2022 21:37:51 +0800 Subject: [PATCH 42/70] fix(nyz): fix policy set device bug --- ding/policy/base_policy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index a30b080015..441d5e7f1a 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -75,7 +75,6 @@ def __init__( if len(set(self._enable_field).intersection(set(['learn']))) > 0: self._rank = get_rank() if self._cfg.learn.multi_gpu else 0 if 
self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() if self._cfg.learn.multi_gpu: bp_update_sync = self._cfg.learn.get('bp_update_sync', True) @@ -84,7 +83,6 @@ def __init__( else: self._rank = 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() self._model = model self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' From 869e63a1d9d336009fdc6d400eadd810daa761ae Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 06:56:37 +0000 Subject: [PATCH 43/70] add offline rl logger --- ding/example/cql.py | 6 ++-- .../middleware/functional/__init__.py | 2 +- .../middleware/functional/evaluator.py | 20 +++++++---- .../framework/middleware/functional/logger.py | 33 +++++++++++++++++-- .../middleware/functional/trainer.py | 18 ++++++---- 5 files changed, 61 insertions(+), 18 deletions(-) diff --git a/ding/example/cql.py b/ding/example/cql.py index 5651121d8d..1e1c678dd0 100644 --- a/ding/example/cql.py +++ b/ding/example/cql.py @@ -5,9 +5,9 @@ from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import create_dataset from ding.config import compile_config -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger from ding.utils import set_pkg_seed from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config @@ -18,6 +18,7 @@ def main(): # For demostration, we also can train a RL policy (e.g. 
SAC) and collect some data logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager @@ -33,6 +34,7 @@ def main(): task.use(offline_data_fetcher(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(offline_logger()) task.run() diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 743c61f8e3..b34596b9dc 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -5,7 +5,7 @@ from .evaluator import interaction_evaluator from .termination_checker import termination_checker from .pace_controller import pace_controller -from .logger import online_logger +from .logger import online_logger, offline_logger from .exchanger import context_exchanger, model_exchanger # algorithm diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index b1c9599953..55ce6d197b 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -13,8 +13,7 @@ from ding.torch_utils import tensor_to_list from ding.utils import lists_to_dicts -if TYPE_CHECKING: - from ding.framework import Context, OnlineRLContext +from ding.framework import Context, OnlineRLContext, OfflineRLContext class IMetric(ABC): @@ -157,7 +156,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> env.seed(cfg.seed, dynamic_seed=False) - def _evaluate(ctx: "OnlineRLContext"): + def _evaluate(ctx: "Context"): """ Overview: - The evaluation will be executed if the task begins and enough train_iter passed \ @@ -196,11 +195,18 @@ def _evaluate(ctx: "OnlineRLContext"): episode_reward = eval_monitor.get_episode_reward() eval_reward = np.mean(episode_reward) stop_flag = eval_reward >= cfg.env.stop_value and ctx.train_iter > 0 - logging.info( - 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( - ctx.train_iter, ctx.env_step, eval_reward + if isinstance(ctx, OnlineRLContext): + logging.info( + 'Evaluation: Train Iter({})\tEnv Step({})\tEval Reward({:.3f})'.format( + ctx.train_iter, ctx.env_step, eval_reward + ) + ) + elif isinstance(ctx, OfflineRLContext): + logging.info( + 'Evaluation: Train Iter({})\tEval Reward({:.3f})'.format( + ctx.train_iter, eval_reward + ) ) - ) ctx.last_eval_iter = ctx.train_iter ctx.eval_value = eval_reward diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 915364c9c1..104be482bb 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -3,7 +3,7 @@ from ding.utils import DistributedWriter if TYPE_CHECKING: - from ding.framework import OnlineRLContext + from ding.framework import OnlineRLContext, OfflineRLContext def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: @@ -48,4 +48,33 @@ def _logger(ctx: "OnlineRLContext"): return _logger -# TODO offline logger +def offline_logger(record_train_iter: bool = False) -> Callable: + writer = DistributedWriter.get_instance() + + def _logger(ctx: "OfflineRLContext"): + if ctx.eval_value is not None: + if 
record_train_iter: + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + if ctx.train_output is not None: + if isinstance(ctx.train_output, deque): + output = ctx.train_output.pop() # only use latest output + else: + output = ctx.train_output + # TODO(nyz) ppo train log case + if isinstance(output, List): + raise NotImplementedError + for k, v in output.items(): + if k in ['priority']: + continue + if "[scalars]" in k: + new_k = k.split(']')[-1] + raise NotImplementedError + elif "[histogram]" in k: + new_k = k.split(']')[-1] + if record_train_iter: + writer.add_histogram(new_k, v, ctx.train_iter) + else: + if record_train_iter: + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + return _logger + diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 5c7f99467b..f9fd432733 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -5,8 +5,7 @@ from ding.policy import Policy from ding.framework import task -if TYPE_CHECKING: - from ding.framework import OnlineRLContext, OfflineRLContext +from ding.framework import OnlineRLContext, OfflineRLContext def trainer(cfg: EasyDict, policy: Policy) -> Callable: @@ -33,11 +32,18 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return train_output = policy.forward(ctx.train_data) if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0: - logging.info( - 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( - ctx.train_iter, ctx.env_step, train_output['total_loss'] + if isinstance(ctx, OnlineRLContext): + logging.info( + 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( + ctx.train_iter, ctx.env_step, train_output['total_loss'] + ) ) - ) + elif isinstance(ctx, OfflineRLContext): + logging.info( + 'Training: Train Iter({})\tLoss({:.3f})'.format( + ctx.train_iter, train_output['total_loss'] + ) + ) ctx.train_iter += 1 ctx.train_output = train_output From 5d0fafd5bf0e59edeff7e7156cf30b87c9b55110 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 08:09:03 +0000 Subject: [PATCH 44/70] change a bit --- .../middleware/functional/evaluator.py | 2 +- .../framework/middleware/functional/logger.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 55ce6d197b..721e3682d4 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -156,7 +156,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> env.seed(cfg.seed, dynamic_seed=False) - def _evaluate(ctx: "Context"): + def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): """ Overview: - The evaluation will be executed if the task begins and enough train_iter passed \ diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 104be482bb..e1272d5a88 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -48,21 +48,14 @@ def _logger(ctx: "OnlineRLContext"): return _logger -def offline_logger(record_train_iter: bool = False) -> Callable: +def offline_logger() -> Callable: writer = DistributedWriter.get_instance() def _logger(ctx: "OfflineRLContext"): if ctx.eval_value is not 
None: - if record_train_iter: - writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) + writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) if ctx.train_output is not None: - if isinstance(ctx.train_output, deque): - output = ctx.train_output.pop() # only use latest output - else: - output = ctx.train_output - # TODO(nyz) ppo train log case - if isinstance(output, List): - raise NotImplementedError + output = ctx.train_output for k, v in output.items(): if k in ['priority']: continue @@ -71,10 +64,8 @@ def _logger(ctx: "OfflineRLContext"): raise NotImplementedError elif "[histogram]" in k: new_k = k.split(']')[-1] - if record_train_iter: - writer.add_histogram(new_k, v, ctx.train_iter) + writer.add_histogram(new_k, v, ctx.train_iter) else: - if record_train_iter: - writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) + writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) return _logger From 3bd5ba54037e33cf6cd11f49bc1dcb594f7dcbe0 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 08:18:15 +0000 Subject: [PATCH 45/70] add else in checking ctx type --- ding/framework/middleware/functional/evaluator.py | 2 ++ ding/framework/middleware/functional/trainer.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 721e3682d4..3c4785053e 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -207,6 +207,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ctx.train_iter, eval_reward ) ) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter ctx.eval_value = eval_reward diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index f9fd432733..3d5db201d0 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -43,7 +43,9 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): 'Training: Train Iter({})\tLoss({:.3f})'.format( ctx.train_iter, train_output['total_loss'] ) - ) + ) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.train_iter += 1 ctx.train_output = train_output From fcaf1ce2ba0ea37bb5e7ef3d385ac871cb4d5fbd Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 14:45:58 +0000 Subject: [PATCH 46/70] add test_logger.py --- .../framework/middleware/tests/test_logger.py | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 ding/framework/middleware/tests/test_logger.py diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py new file mode 100644 index 0000000000..07a5b3ddf1 --- /dev/null +++ b/ding/framework/middleware/tests/test_logger.py @@ -0,0 +1,142 @@ +import pytest +from ding.framework import OnlineRLContext, OfflineRLContext, ding_init +from ding.framework.middleware.functional import online_logger, offline_logger +from easydict import EasyDict +import os +from os import path +import shutil +from collections import deque + +test_folder = "test_exp" +test_path = path.join(os.getcwd(), test_folder) +cfg = EasyDict({"exp_name": "test_exp"}) + +@pytest.fixture(scope='function') +def online_ctx_output_dict(): + ctx = 
OnlineRLContext() + ctx.eval_value = -10000 + ctx.train_iter = 34 + ctx.env_step = 78 + ctx.train_output = { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + } + return ctx + +@pytest.fixture(scope='function') +def online_ctx_output_deque(): + ctx = OnlineRLContext() + ctx.eval_value = -600 + ctx.train_iter = 24 + ctx.env_step = 30 + ctx.train_output = deque([ + { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + }, + { + 'priority': [108], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 30 + } + ]) + return ctx + +@pytest.fixture(scope='function') +def online_ctx_output_list(): + ctx = OnlineRLContext() + ctx.eval_value = -1000000 + ctx.train_iter = 23232 + ctx.env_step = 33333 + ctx.train_output = [ + { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + }, + { + 'priority': [108], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 30 + } + ] + return ctx + +@pytest.fixture(scope='function') +def online_scalar_ctx(): + ctx = OfflineRLContext() + ctx.eval_value = -777888 + ctx.train_iter = 2233 + ctx.env_step = 32323 + ctx.train_output = { + '[scalars]': 1 + } + return ctx + + +@pytest.mark.zms +class TestOnlineLogger: + + def test_online_logger_output_dict(self, online_ctx_output_dict): + ding_init(cfg) + online_logger()(online_ctx_output_dict) + + def test_online_logger_record_output_dict(self, online_ctx_output_dict): + ding_init(cfg) + online_logger(record_train_iter=True)(online_ctx_output_dict) + + def test_online_logger_record_output_deque(self, online_ctx_output_deque): + ding_init(cfg) + online_logger()(online_ctx_output_deque) + + def test_online_logger_record_output_list(self, online_ctx_output_list): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_ctx_output_list) + + def test_online_logger_scalars(self, online_scalar_ctx): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_scalar_ctx) + + +@pytest.fixture(scope='function') +def offline_ctx_output_dict(): + ctx = OfflineRLContext() + ctx.eval_value = -10000000000 + ctx.train_iter = 3323233 + ctx.train_output = { + 'priority': [107], + '[histogram]test_histogram': [1,2,3,4,5,6], + 'td_error': 15 + } + return ctx + +@pytest.fixture(scope='function') +def offline_scalar_ctx(): + ctx = OfflineRLContext() + ctx.eval_value = -232 + ctx.train_iter = 3333 + ctx.train_output = { + '[scalars]': 1 + } + return ctx + +@pytest.mark.zms +class TestOfflineLogger: + + def test_offline_logger_no_scalars(self, offline_ctx_output_dict): + ding_init(cfg) + offline_logger()(offline_ctx_output_dict) + + def test_offline_logger_scalars(self, offline_scalar_ctx): + ding_init(cfg) + with pytest.raises(NotImplementedError) as exc_info: + offline_logger()(offline_scalar_ctx) + + assert path.exists(test_path) + if path.exists(test_path): + shutil.rmtree(test_path) + From 52e3500bba6bb4013b422e95d19993b34021861a Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 15:42:10 +0000 Subject: [PATCH 47/70] add mock of offline_logger --- .../framework/middleware/tests/test_logger.py | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index 07a5b3ddf1..af4493918a 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -6,6 
+6,9 @@ from os import path import shutil from collections import deque +from unittest.mock import Mock, patch +from ding.utils import DistributedWriter +import copy test_folder = "test_exp" test_path = path.join(os.getcwd(), test_folder) @@ -101,12 +104,10 @@ def test_online_logger_scalars(self, online_scalar_ctx): with pytest.raises(NotImplementedError) as exc_info: online_logger()(online_scalar_ctx) - -@pytest.fixture(scope='function') -def offline_ctx_output_dict(): +def get_offline_ctx(): ctx = OfflineRLContext() ctx.eval_value = -10000000000 - ctx.train_iter = 3323233 + ctx.train_iter = 3333 ctx.train_output = { 'priority': [107], '[histogram]test_histogram': [1,2,3,4,5,6], @@ -114,29 +115,52 @@ def offline_ctx_output_dict(): } return ctx +@pytest.fixture(scope='function') +def offline_ctx_output_dict(): + ctx = get_offline_ctx() + return ctx + @pytest.fixture(scope='function') def offline_scalar_ctx(): - ctx = OfflineRLContext() - ctx.eval_value = -232 - ctx.train_iter = 3333 + ctx = get_offline_ctx() ctx.train_output = { '[scalars]': 1 } return ctx -@pytest.mark.zms +class MockWriter: + + def __init__(self): + self.ctx = get_offline_ctx() + + def add_scalar(self, tag, scalar_value, global_step): + assert global_step == self.ctx.train_iter + if tag == 'basic/eval_episode_reward_mean-train_iter': + assert scalar_value == self.ctx.eval_value + elif tag == 'basic/train_td_error-train_iter': + assert scalar_value == self.ctx.train_output['td_error'] + else: + raise NotImplementedError('tag should be in the tags defined') + + def add_histogram(self, tag, values, global_step): + assert tag == 'test_histogram' + assert values == [1,2,3,4,5,6] + assert global_step == self.ctx.train_iter + +def mock_get_instance(): + return MockWriter() + + +@pytest.mark.offline class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): - ding_init(cfg) - offline_logger()(offline_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + offline_logger()(offline_ctx_output_dict) def test_offline_logger_scalars(self, offline_scalar_ctx): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - offline_logger()(offline_scalar_ctx) - - assert path.exists(test_path) - if path.exists(test_path): - shutil.rmtree(test_path) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with pytest.raises(NotImplementedError) as exc_info: + offline_logger()(offline_scalar_ctx) + From 48106f1ad86541ec9d19db9fd8e79d0d0c06de77 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:06:46 +0000 Subject: [PATCH 48/70] add mock of online writer --- .../framework/middleware/tests/test_logger.py | 111 +++++++++--------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index af4493918a..bf2d5df45d 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -14,8 +14,7 @@ test_path = path.join(os.getcwd(), test_folder) cfg = EasyDict({"exp_name": "test_exp"}) -@pytest.fixture(scope='function') -def online_ctx_output_dict(): +def get_online_ctx(): ctx = OnlineRLContext() ctx.eval_value = -10000 ctx.train_iter = 34 @@ -27,82 +26,88 @@ def online_ctx_output_dict(): } return ctx +@pytest.fixture(scope='function') +def online_ctx_output_dict(): + ctx = get_online_ctx() + return ctx + @pytest.fixture(scope='function') def 
online_ctx_output_deque(): - ctx = OnlineRLContext() - ctx.eval_value = -600 - ctx.train_iter = 24 - ctx.env_step = 30 + ctx = get_online_ctx() ctx.train_output = deque([ - { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - }, - { - 'priority': [108], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 30 - } + ctx.train_output ]) return ctx @pytest.fixture(scope='function') def online_ctx_output_list(): - ctx = OnlineRLContext() - ctx.eval_value = -1000000 - ctx.train_iter = 23232 - ctx.env_step = 33333 + ctx = get_online_ctx() ctx.train_output = [ - { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - }, - { - 'priority': [108], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 30 - } + ctx.train_output ] return ctx @pytest.fixture(scope='function') def online_scalar_ctx(): - ctx = OfflineRLContext() - ctx.eval_value = -777888 - ctx.train_iter = 2233 - ctx.env_step = 32323 + ctx = get_online_ctx() ctx.train_output = { '[scalars]': 1 } return ctx +class MockOnlineWriter: + def __init__(self): + self.ctx = get_online_ctx() + + def add_scalar(self, tag, scalar_value, global_step): + if tag in ['basic/eval_episode_reward_mean-env_step', 'basic/eval_episode_reward_mean']: + assert scalar_value == self.ctx.eval_value + assert global_step == self.ctx.env_step + elif tag == 'basic/eval_episode_reward_mean-train_iter': + assert scalar_value == self.ctx.eval_value + assert global_step == self.ctx.train_iter + elif tag in ['basic/train_td_error-env_step', 'basic/train_td_error']: + assert scalar_value == self.ctx.train_output['td_error'] + assert global_step == self.ctx.env_step + elif tag == 'basic/train_td_error-train_iter': + assert scalar_value == self.ctx.train_output['td_error'] + assert global_step == self.ctx.train_iter + else: + raise NotImplementedError('tag should be in the tags defined') + + def add_histogram(self, tag, values, global_step): + assert tag == 'test_histogram' + assert values == [1,2,3,4,5,6] + assert global_step in [self.ctx.train_iter, self.ctx.env_step] + +def mock_get_online_instance(): + return MockOnlineWriter() -@pytest.mark.zms +@pytest.mark.unittest class TestOnlineLogger: def test_online_logger_output_dict(self, online_ctx_output_dict): - ding_init(cfg) - online_logger()(online_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger()(online_ctx_output_dict) def test_online_logger_record_output_dict(self, online_ctx_output_dict): - ding_init(cfg) - online_logger(record_train_iter=True)(online_ctx_output_dict) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger(record_train_iter=True)(online_ctx_output_dict) def test_online_logger_record_output_deque(self, online_ctx_output_deque): - ding_init(cfg) - online_logger()(online_ctx_output_deque) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + online_logger()(online_ctx_output_deque) def test_online_logger_record_output_list(self, online_ctx_output_list): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_ctx_output_list) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_ctx_output_list) def test_online_logger_scalars(self, online_scalar_ctx): - ding_init(cfg) - with pytest.raises(NotImplementedError) as exc_info: - 
online_logger()(online_scalar_ctx) + with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): + with pytest.raises(NotImplementedError) as exc_info: + online_logger()(online_scalar_ctx) + def get_offline_ctx(): ctx = OfflineRLContext() @@ -128,7 +133,7 @@ def offline_scalar_ctx(): } return ctx -class MockWriter: +class MockOfflineWriter: def __init__(self): self.ctx = get_offline_ctx() @@ -147,19 +152,19 @@ def add_histogram(self, tag, values, global_step): assert values == [1,2,3,4,5,6] assert global_step == self.ctx.train_iter -def mock_get_instance(): - return MockWriter() +def mock_get_offline_instance(): + return MockOfflineWriter() -@pytest.mark.offline +@pytest.mark.unittest class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): offline_logger()(offline_ctx_output_dict) def test_offline_logger_scalars(self, offline_scalar_ctx): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_instance): + with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): with pytest.raises(NotImplementedError) as exc_info: offline_logger()(offline_scalar_ctx) From 0de3bed3be2f34c84e3781cbc254ffc3691a62c7 Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:09:23 +0000 Subject: [PATCH 49/70] reformat --- .../framework/middleware/tests/test_logger.py | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index bf2d5df45d..56faf43a62 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -14,51 +14,48 @@ test_path = path.join(os.getcwd(), test_folder) cfg = EasyDict({"exp_name": "test_exp"}) + def get_online_ctx(): ctx = OnlineRLContext() ctx.eval_value = -10000 ctx.train_iter = 34 ctx.env_step = 78 - ctx.train_output = { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - } + ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} return ctx + @pytest.fixture(scope='function') def online_ctx_output_dict(): ctx = get_online_ctx() return ctx + @pytest.fixture(scope='function') def online_ctx_output_deque(): ctx = get_online_ctx() - ctx.train_output = deque([ - ctx.train_output - ]) + ctx.train_output = deque([ctx.train_output]) return ctx + @pytest.fixture(scope='function') def online_ctx_output_list(): ctx = get_online_ctx() - ctx.train_output = [ - ctx.train_output - ] + ctx.train_output = [ctx.train_output] return ctx + @pytest.fixture(scope='function') def online_scalar_ctx(): ctx = get_online_ctx() - ctx.train_output = { - '[scalars]': 1 - } + ctx.train_output = {'[scalars]': 1} return ctx + class MockOnlineWriter: + def __init__(self): self.ctx = get_online_ctx() - + def add_scalar(self, tag, scalar_value, global_step): if tag in ['basic/eval_episode_reward_mean-env_step', 'basic/eval_episode_reward_mean']: assert scalar_value == self.ctx.eval_value @@ -74,15 +71,17 @@ def add_scalar(self, tag, scalar_value, global_step): assert global_step == self.ctx.train_iter else: raise NotImplementedError('tag should be in the tags defined') - + def add_histogram(self, tag, values, global_step): assert tag == 'test_histogram' - assert values == 
[1,2,3,4,5,6] + assert values == [1, 2, 3, 4, 5, 6] assert global_step in [self.ctx.train_iter, self.ctx.env_step] + def mock_get_online_instance(): return MockOnlineWriter() + @pytest.mark.unittest class TestOnlineLogger: @@ -97,12 +96,12 @@ def test_online_logger_record_output_dict(self, online_ctx_output_dict): def test_online_logger_record_output_deque(self, online_ctx_output_deque): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): online_logger()(online_ctx_output_deque) - + def test_online_logger_record_output_list(self, online_ctx_output_list): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): with pytest.raises(NotImplementedError) as exc_info: online_logger()(online_ctx_output_list) - + def test_online_logger_scalars(self, online_scalar_ctx): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): with pytest.raises(NotImplementedError) as exc_info: @@ -113,31 +112,28 @@ def get_offline_ctx(): ctx = OfflineRLContext() ctx.eval_value = -10000000000 ctx.train_iter = 3333 - ctx.train_output = { - 'priority': [107], - '[histogram]test_histogram': [1,2,3,4,5,6], - 'td_error': 15 - } + ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} return ctx + @pytest.fixture(scope='function') def offline_ctx_output_dict(): ctx = get_offline_ctx() return ctx + @pytest.fixture(scope='function') def offline_scalar_ctx(): ctx = get_offline_ctx() - ctx.train_output = { - '[scalars]': 1 - } + ctx.train_output = {'[scalars]': 1} return ctx + class MockOfflineWriter: def __init__(self): self.ctx = get_offline_ctx() - + def add_scalar(self, tag, scalar_value, global_step): assert global_step == self.ctx.train_iter if tag == 'basic/eval_episode_reward_mean-train_iter': @@ -149,9 +145,10 @@ def add_scalar(self, tag, scalar_value, global_step): def add_histogram(self, tag, values, global_step): assert tag == 'test_histogram' - assert values == [1,2,3,4,5,6] + assert values == [1, 2, 3, 4, 5, 6] assert global_step == self.ctx.train_iter + def mock_get_offline_instance(): return MockOfflineWriter() @@ -162,10 +159,8 @@ class TestOfflineLogger: def test_offline_logger_no_scalars(self, offline_ctx_output_dict): with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): offline_logger()(offline_ctx_output_dict) - + def test_offline_logger_scalars(self, offline_scalar_ctx): with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): with pytest.raises(NotImplementedError) as exc_info: offline_logger()(offline_scalar_ctx) - - From 3063e47050efa0773abec5e4a06af06b52fb3e4d Mon Sep 17 00:00:00 2001 From: zhumengshen <744762298@qq.com> Date: Mon, 6 Jun 2022 16:13:25 +0000 Subject: [PATCH 50/70] reformat --- ding/framework/middleware/functional/evaluator.py | 6 +----- ding/framework/middleware/functional/logger.py | 2 +- ding/framework/middleware/functional/trainer.py | 6 ++---- ding/framework/middleware/tests/test_logger.py | 10 ---------- 4 files changed, 4 insertions(+), 20 deletions(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 3c4785053e..f06f1b8602 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -202,11 +202,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ) ) elif isinstance(ctx, OfflineRLContext): - logging.info( - 'Evaluation: Train 
Iter({})\tEval Reward({:.3f})'.format( - ctx.train_iter, eval_reward - ) - ) + logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, eval_reward)) else: raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index e1272d5a88..40e2f23732 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -67,5 +67,5 @@ def _logger(ctx: "OfflineRLContext"): writer.add_histogram(new_k, v, ctx.train_iter) else: writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) - return _logger + return _logger diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 3d5db201d0..990ff69f30 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -40,12 +40,10 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): ) elif isinstance(ctx, OfflineRLContext): logging.info( - 'Training: Train Iter({})\tLoss({:.3f})'.format( - ctx.train_iter, train_output['total_loss'] - ) + 'Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output['total_loss']) ) else: - raise TypeError("not supported ctx type: {}".format(type(ctx))) + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.train_iter += 1 ctx.train_output = train_output diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index 56faf43a62..1c0a772f1d 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -97,16 +97,6 @@ def test_online_logger_record_output_deque(self, online_ctx_output_deque): with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): online_logger()(online_ctx_output_deque) - def test_online_logger_record_output_list(self, online_ctx_output_list): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_ctx_output_list) - - def test_online_logger_scalars(self, online_scalar_ctx): - with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): - with pytest.raises(NotImplementedError) as exc_info: - online_logger()(online_scalar_ctx) - def get_offline_ctx(): ctx = OfflineRLContext() From d60118528222a8e17268e877dd7f76889f2b8170 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 7 Jun 2022 21:25:33 +0800 Subject: [PATCH 51/70] feature(nyz): polish atari ddp demo and add dist demo --- .../middleware/functional/__init__.py | 2 +- .../functional/termination_checker.py | 25 ++++++ dizoo/atari/example/atari_dqn_ddp.py | 32 +------ dizoo/atari/example/atari_dqn_dist.py | 85 +++++++++++++++++++ 4 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 dizoo/atari/example/atari_dqn_dist.py diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index b34596b9dc..a925b48494 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -3,7 +3,7 @@ sqil_data_pusher from .collector import inferencer, rolloutor, TransitionList from .evaluator import interaction_evaluator -from .termination_checker import termination_checker +from .termination_checker 
import termination_checker, ddp_termination_checker from .pace_controller import pace_controller from .logger import online_logger, offline_logger from .exchanger import context_exchanger, model_exchanger diff --git a/ding/framework/middleware/functional/termination_checker.py b/ding/framework/middleware/functional/termination_checker.py index 58c371d57b..b6879c2a16 100644 --- a/ding/framework/middleware/functional/termination_checker.py +++ b/ding/framework/middleware/functional/termination_checker.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, Union, Callable, Optional import numpy as np +import torch +from ding.utils import broadcast from ding.framework import task if TYPE_CHECKING: @@ -20,3 +22,26 @@ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]): task.finish = True return _check + + +def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): + if rank == 0: + if max_env_step is None: + max_env_step = np.inf + if max_train_iter is None: + max_train_iter = np.inf + + def _check(ctx): + if rank == 0: + if ctx.env_step > max_env_step: + finish = torch.ones(1).long().cuda() + elif ctx.train_iter > max_train_iter: + finish = torch.ones(1).long().cuda() + else: + finish = torch.LongTensor([task.finish]).cuda() + else: + finish = torch.zeros(1).long().cuda() + broadcast(finish, 0) + task.finish = finish.cpu().bool().item() + + return _check diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py index bc07da3bfb..22e30eba89 100644 --- a/dizoo/atari/example/atari_dqn_ddp.py +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -6,46 +6,21 @@ from ding.data import DequeBuffer from ding.config import compile_config from ding.utils import DistContext, get_rank -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OnlineRLContext from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver, nstep_reward_enhancer + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, online_logger, ddp_termination_checker from ding.utils import set_pkg_seed from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config -def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): - import numpy as np - import torch - from ding.utils import broadcast - if rank == 0: - if max_env_step is None: - max_env_step = np.inf - if max_train_iter is None: - max_train_iter = np.inf - - def _check(ctx): - if rank == 0: - if ctx.env_step > max_env_step: - finish = torch.ones(1).long().cuda() - elif ctx.train_iter > max_train_iter: - finish = torch.ones(1).long().cuda() - else: - finish = torch.LongTensor([task.finish]).cuda() - else: - finish = torch.zeros(1).long().cuda() - broadcast(finish, 0) - task.finish = finish.cpu().bool().item() - - return _check - - def main(): logging.getLogger().setLevel(logging.INFO) main_config.exp_name = 'pong_dqn_seed0_ddp' main_config.policy.learn.multi_gpu = True cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with DistContext(): rank = get_rank() with task.start(async_mode=False, ctx=OnlineRLContext()): @@ -75,6 +50,7 @@ def main(): task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) if rank == 0: task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(online_logger(record_train_iter=True)) task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) 
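# Usage sketch, not part of the patch: the synchronization idea behind ddp_termination_checker
# above, written directly against torch.distributed (ding.utils.broadcast is assumed to wrap
# dist.broadcast). It assumes an already initialized process group with one CUDA device per
# rank: rank 0 decides whether training should stop, and the decision is broadcast so that
# every DDP worker leaves the loop in the same iteration instead of hanging in a collective.
import torch
import torch.distributed as dist


def should_stop(rank: int, env_step: int, train_iter: int, max_env_step: float, max_train_iter: float) -> bool:
    if rank == 0:
        done = (env_step > max_env_step) or (train_iter > max_train_iter)
        flag = torch.tensor([int(done)], dtype=torch.long).cuda()
    else:
        flag = torch.zeros(1, dtype=torch.long).cuda()
    dist.broadcast(flag, src=0)  # all ranks receive rank 0's decision
    return bool(flag.item())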
task.run() diff --git a/dizoo/atari/example/atari_dqn_dist.py b/dizoo/atari/example/atari_dqn_dist.py new file mode 100644 index 0000000000..d692a9f3e3 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist.py @@ -0,0 +1,85 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ditask_dist' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'evaluator' in task.router.labels: + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == 
"__main__": + main() From be27ff7a66674859f4abb01656ce40205b628c90 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 15:20:53 +0800 Subject: [PATCH 52/70] fix(nyz): fix mq listen bug when stop --- ding/framework/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index b61149debf..7003a47840 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -264,7 +264,7 @@ def padding_param(cls, int_or_list: Optional[Union[List[int], int]], n_max: int, def listen(self): self._mq.listen() - while True: + while self._mq is not None: msg = self._mq.recv() # msg is none means that the message queue is no longer being listened to, # especially if the message queue is already closed From e5868dfc3da5f0d9608a7e09181344fe31ba8304 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 15:23:17 +0800 Subject: [PATCH 53/70] demo(nyz): add atari ppo(sm+ddp) demo --- .../functional/advantage_estimator.py | 4 ++ .../functional/termination_checker.py | 6 +++ dizoo/atari/example/atari_ppo.py | 47 ++++++++++++++++ dizoo/atari/example/atari_ppo_ddp.py | 54 +++++++++++++++++++ 4 files changed, 111 insertions(+) create mode 100644 dizoo/atari/example/atari_ppo.py create mode 100644 dizoo/atari/example/atari_ppo_ddp.py diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index 57fe866b2b..3ada84c5dd 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -42,6 +42,8 @@ def _gae(ctx: "OnlineRLContext"): data = ctx.trajectories # list data = ttorch_collate(data) with torch.no_grad(): + if cfg.policy.cuda: + data = data.cuda() value = model.forward(data.obs, mode='compute_critic')['value'] next_value = model.forward(data.next_obs, mode='compute_critic')['value'] data.value = value @@ -53,6 +55,8 @@ def _gae(ctx: "OnlineRLContext"): # done is bool type when acquired from env.step data_ = gae_data(data.value, next_value, data.reward, data.done.float(), traj_flag.float()) data.adv = gae(data_, cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda) + if cfg.policy.cuda: + data = data.cpu() if buffer_ is None: ctx.train_data = data else: diff --git a/ding/framework/middleware/functional/termination_checker.py b/ding/framework/middleware/functional/termination_checker.py index b6879c2a16..3f7cdc0cc4 100644 --- a/ding/framework/middleware/functional/termination_checker.py +++ b/ding/framework/middleware/functional/termination_checker.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Union, Callable, Optional +from ditk import logging import numpy as np import torch from ding.utils import broadcast @@ -18,8 +19,10 @@ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]): # ">" is better than ">=" when taking logger result into consideration if ctx.env_step > max_env_step: task.finish = True + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) if ctx.train_iter > max_train_iter: task.finish = True + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) return _check @@ -35,12 +38,15 @@ def _check(ctx): if rank == 0: if ctx.env_step > max_env_step: finish = torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) elif ctx.train_iter > max_train_iter: finish 
= torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) else: finish = torch.LongTensor([task.finish]).cuda() else: finish = torch.zeros(1).long().cuda() + # broadcast finish result to other DDP workers broadcast(finish, 0) task.finish = finish.cpu().bool().item() diff --git a/dizoo/atari/example/atari_ppo.py b/dizoo/atari/example/atari_ppo.py new file mode 100644 index 0000000000..94b99ca8c2 --- /dev/null +++ b/dizoo/atari/example/atari_ppo.py @@ -0,0 +1,47 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_env_step=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_ppo_ddp.py b/dizoo/atari/example/atari_ppo_ddp.py new file mode 100644 index 0000000000..eea7bea0d6 --- /dev/null +++ b/dizoo/atari/example/atari_ppo_ddp.py @@ -0,0 +1,54 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, ddp_termination_checker, online_logger +from ding.utils import set_pkg_seed, DistContext, get_rank +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.example = 'pong_ppo_seed0_ddp' + main_config.policy.learn.multi_gpu = True + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + 
with DistContext(): + rank = get_rank() + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() From d637d2bfe46e5c2996a847c4e7d6649b646598d9 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 16:55:25 +0800 Subject: [PATCH 54/70] demo(nyz): add ppo ddp avgsplit demo --- dizoo/atari/example/atari_ppo_ddp.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dizoo/atari/example/atari_ppo_ddp.py b/dizoo/atari/example/atari_ppo_ddp.py index eea7bea0d6..e498e03394 100644 --- a/dizoo/atari/example/atari_ppo_ddp.py +++ b/dizoo/atari/example/atari_ppo_ddp.py @@ -9,19 +9,21 @@ from ding.framework.context import OnlineRLContext from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ gae_estimator, ddp_termination_checker, online_logger -from ding.utils import set_pkg_seed, DistContext, get_rank +from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size from dizoo.atari.envs.atari_env import AtariEnv from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config def main(): logging.getLogger().setLevel(logging.INFO) - main_config.example = 'pong_ppo_seed0_ddp' - main_config.policy.learn.multi_gpu = True - cfg = compile_config(main_config, create_cfg=create_config, auto=True) - ding_init(cfg) with DistContext(): - rank = get_rank() + rank, world_size = get_rank(), get_world_size() + main_config.example = 'pong_ppo_seed0_ddp_avgsplit' + main_config.policy.learn.multi_gpu = True + main_config.policy.learn.batch_size = main_config.policy.learn.batch_size // world_size + main_config.policy.collect.n_sample = main_config.policy.collect.n_sample // world_size + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_cfg = deepcopy(cfg.env) collector_cfg.is_train = True From f72b6c93625284ce9f76b50662da3a7735b80d7d Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 8 Jun 2022 18:29:52 +0800 Subject: [PATCH 55/70] demo(nyz): add ditask + pytorch ddp demo --- ding/entry/cli_ditask.py | 2 + dizoo/atari/example/atari_dqn_dist_ddp.py | 89 +++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 dizoo/atari/example/atari_dqn_dist_ddp.py diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index 68ec836fe6..f05bd5257a 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -62,6 +62,7 @@ def print_version(ctx: Context, 
param: Option, value: bool) -> None: @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @click.option("-m", "--main", type=str, help="Main function of entry module.") +@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP") def cli_ditask(*args, **kwargs): return _cli_ditask(*args, **kwargs) @@ -105,6 +106,7 @@ def _cli_ditask( mq_type: str, redis_host: str, redis_port: int, + local_rank: int = 0, platform: str = None, platform_spec: str = None, ): diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py new file mode 100644 index 0000000000..0ca678a4d2 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -0,0 +1,89 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
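# Background sketch, not part of the patch: the --local_rank option added to cli_ditask above
# exists because PyTorch's legacy launcher (`python -m torch.distributed.launch
# --nproc_per_node=N <entry>`) appends `--local_rank=<k>` to each worker's argv, so an entry
# point that rejects the flag would fail before DDP is even initialized. A minimal entry that
# consumes it (the helper name below is illustrative, not a DI-engine API):
import argparse
import torch


def parse_local_rank() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args, _ = parser.parse_known_args()  # ignore the other ditask CLI options
    torch.cuda.set_device(args.local_rank)  # bind this worker process to its own GPU
    return args.local_rank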
+ + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + from ding.utils import DistContext, get_rank + with DistContext(): + rank = get_rank() + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'evaluator' in task.router.labels: + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() From 84447855befdd99a245fae2ac4a973f936f7cd96 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 9 Jun 2022 20:54:49 +0800 Subject: [PATCH 56/70] fix(nyz): fix dict-type obs bugs --- ding/envs/env_manager/base_env_manager.py | 2 ++ ding/envs/env_manager/subprocess_env_manager.py | 2 ++ ding/framework/middleware/functional/collector.py | 3 ++- ding/framework/middleware/functional/evaluator.py | 4 ++-- ding/torch_utils/__init__.py | 2 +- ding/torch_utils/data_helper.py | 10 ++++++++++ ding/utils/data/collate_fn.py | 7 ++++--- 7 files changed, 23 insertions(+), 7 deletions(-) diff --git a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py index b335db8dd2..f0f92c0e3e 100644 --- a/ding/envs/env_manager/base_env_manager.py +++ b/ding/envs/env_manager/base_env_manager.py @@ -430,6 +430,8 @@ def ready_obs(self) -> tnp.array: """ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] obs = [self._ready_obs[i] for i in active_env] + if isinstance(obs[0], dict): + obs = [tnp.array(o) for o in obs] return tnp.stack(obs) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: diff --git a/ding/envs/env_manager/subprocess_env_manager.py 
b/ding/envs/env_manager/subprocess_env_manager.py index b0dfb415ec..9815b26c2a 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -933,6 +933,8 @@ def ready_obs(self) -> tnp.array: time.sleep(0.001) sleep_count += 1 obs = [self._ready_obs[i] for i in self.ready_env] + if isinstance(obs[0], dict): + obs = [tnp.array(o) for o in obs] return tnp.stack(obs) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index ffa2a01f58..6b1e876fe0 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -4,6 +4,7 @@ import treetensor.torch as ttorch from ding.envs import BaseEnvManager from ding.policy import Policy +from ding.torch_utils import get_shape0 if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -73,7 +74,7 @@ def _inference(ctx: "OnlineRLContext"): ctx.obs = obs # TODO mask necessary rollout - obs = {i: obs[i] for i in range(obs.shape[0])} # TBD + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs, **ctx.collect_kwargs) ctx.action = [v['action'].numpy() for v in inference_output.values()] # TBD ctx.inference_output = inference_output diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index f06f1b8602..2d80a3aca0 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -10,7 +10,7 @@ from ding.policy import Policy from ding.data import Dataset, DataLoader from ding.framework import task -from ding.torch_utils import tensor_to_list +from ding.torch_utils import tensor_to_list, get_shape0 from ding.utils import lists_to_dicts from ding.framework import Context, OnlineRLContext, OfflineRLContext @@ -182,7 +182,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): while not eval_monitor.is_finished(): obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32) - obs = {i: obs[i] for i in range(obs.shape[0])} # TBD + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs) action = [v['action'].numpy() for v in inference_output.values()] # TBD timesteps = env.step(action) diff --git a/ding/torch_utils/__init__.py b/ding/torch_utils/__init__.py index ca48761136..8e942f416a 100644 --- a/ding/torch_utils/__init__.py +++ b/ding/torch_utils/__init__.py @@ -1,6 +1,6 @@ from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \ - build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data + build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data, get_shape0 from .distribution import CategoricalPd, CategoricalPdPytorch from .metric import levenshtein_distance, hamming_distance from .network import * diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 9e0b8e7861..51c4d51664 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -8,6 +8,7 @@ import numpy as np import torch +import treetensor.torch as ttorch def to_device(item: Any, device: str, ignore_keys: list = []) -> Any: @@ -397,3 +398,12 @@ def get_null_data(template: Any, num: int) -> List[Any]: data['reward'].zero_() ret.append(data) 
return ret + + +def get_shape0(data): + if isinstance(data, torch.Tensor): + return data.shape[0] + elif isinstance(data, ttorch.Tensor): + return list(data.shape.values())[0][0] + else: + raise TypeError("not support type: {}".format(data)) diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index e9782496c1..758b239416 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -56,11 +56,12 @@ def default_collate(batch: Sequence, - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with batch size into each data field.\ the return dtype depends on the original element dtype, can be [torch.Tensor, Mapping, Sequence]. """ - elem = batch[0] - elem_type = type(elem) if isinstance(batch, ttorch.Tensor): return batch.json() + + elem = batch[0] + elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch_ge_131() and torch.utils.data.get_worker_info() is not None: @@ -78,7 +79,7 @@ def default_collate(batch: Sequence, elif isinstance(elem, ttorch.Tensor): ret = ttorch.stack(batch).json() for k in ret: - if len(ret[k].shape) == 2 and ret[k].shape[1] == 1: # reshape (B, 1) -> (B) + if hasattr(ret[k], 'shape') and len(ret[k].shape) >= 2 and ret[k].shape[1] == 1: # reshape (B, 1) -> (B) ret[k] = ret[k].squeeze(1) return ret elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ From d0b1c0097ef271530b7437dbe5d75ece438e9436 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Fri, 10 Jun 2022 15:58:02 +0800 Subject: [PATCH 57/70] fix(nyz): fix get_shape0 bug when nested structure --- ding/torch_utils/data_helper.py | 10 +++++++++- ding/torch_utils/tests/test_data_helper.py | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 51c4d51664..d49bb6f957 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -404,6 +404,14 @@ def get_shape0(data): if isinstance(data, torch.Tensor): return data.shape[0] elif isinstance(data, ttorch.Tensor): - return list(data.shape.values())[0][0] + + def fn(t): + item = list(t.values())[0] + if np.isscalar(item[0]): + return item[0] + else: + return fn(item) + + return fn(data.shape) else: raise TypeError("not support type: {}".format(data)) diff --git a/ding/torch_utils/tests/test_data_helper.py b/ding/torch_utils/tests/test_data_helper.py index 218ce59ba7..629d081b0a 100644 --- a/ding/torch_utils/tests/test_data_helper.py +++ b/ding/torch_utils/tests/test_data_helper.py @@ -4,9 +4,10 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader +import treetensor.torch as ttorch from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \ - tensor_to_list, same_shape, build_log_buffer, get_tensor_data + tensor_to_list, same_shape, build_log_buffer, get_tensor_data, get_shape0 from ding.utils import EasyTimer @@ -132,6 +133,18 @@ def test_get_tensor_data(self): with pytest.raises(TypeError): get_tensor_data(EasyTimer()) + def test_get_shape0(self): + a = { + 'a': { + 'b': torch.randn(4, 3) + }, + 'c': { + 'd': torch.randn(4) + }, + } + a = ttorch.as_tensor(a) + assert get_shape0(a) == 4 + @pytest.mark.unittest def test_log_dict(): From 0fb89c8390dd2c7835350dc832e55a63a3bcc7e9 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Thu, 9 Jun 2022 15:20:14 +0800 Subject: [PATCH 58/70] Route finish event to all processes in the cluster --- ding/framework/task.py | 4 +-- 
ding/framework/tests/test_parallel.py | 2 -- ding/framework/tests/test_task.py | 50 +++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/ding/framework/task.py b/ding/framework/task.py index 53e95716b0..d67c0558b4 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -330,6 +330,8 @@ def stop(self) -> None: Overview: Stop and cleanup every thing in the runtime of task. """ + if self.router.is_active: + self.emit("finish", True) if self._thread_pool: self._thread_pool.shutdown() self._event_loop.stop() @@ -472,8 +474,6 @@ def finish(self): @finish.setter def finish(self, value: bool): self._finish = value - if self.router.is_active and value is True: - self.emit("finish", value) def _wrap_event_name(self, event: str) -> str: """ diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py index b042cb3a57..3c7c190f0c 100644 --- a/ding/framework/tests/test_parallel.py +++ b/ding/framework/tests/test_parallel.py @@ -1,9 +1,7 @@ from collections import defaultdict import pytest import time -import os from ding.framework import Parallel -from ding.utils.design_helper import SingletonMetaclass def parallel_main(): diff --git a/ding/framework/tests/test_task.py b/ding/framework/tests/test_task.py index c9d1243b6e..36a80d23f0 100644 --- a/ding/framework/tests/test_task.py +++ b/ding/framework/tests/test_task.py @@ -1,6 +1,7 @@ +import multiprocessing as mp import pytest -from threading import Lock -from time import sleep +from threading import Lock, Thread +from time import sleep, time import random from ding.framework import task, Context, Parallel @@ -331,3 +332,48 @@ def slowest(ctx): task.use(fast, lock=lock) task.run(1) assert task.ctx.result == "slowest" + + +def broadcast_finish_main(): + with task.start(): + + def tick(ctx: Context): + if task.router.node_id == 1 and ctx.total_step == 1: + task.finish = True + sleep(1) + + task.use(tick) + task.run(20) + + +def broadcast_main_target(): + Parallel.runner( + n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555 + )(broadcast_finish_main) + + +def broadcast_secondary_target(): + "Start two standalone processes and connect to the main process." 
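+ # Two standalone workers (node_ids 1 and 2) attach to the single mesh worker listening
+ # on port 50555; node 1 sets task.finish early on, and the finish event must be routed
+ # to every other process so that all of them stop well before the 20-iteration limit.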
+ Parallel.runner( + n_parallel_workers=2, + protocol="tcp", + address="127.0.0.1", + topology="alone", + ports=50556, + attach_to=["tcp://127.0.0.1:50555"], + node_ids=[1, 2] + )(broadcast_finish_main) + + +@pytest.mark.unittest +@pytest.mark.timeout(10) +def test_broadcast_finish(): + start = time() + ctx = mp.get_context("spawn") + main_process = ctx.Process(target=broadcast_main_target) + secondary_process = ctx.Process(target=broadcast_secondary_target) + main_process.start() + secondary_process.start() + main_process.join() + secondary_process.join() + assert (time() - start) < 10 From 84ea5cb14a99bce95a0f780a4ab357e426635e67 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 9 Jun 2022 16:23:33 +0800 Subject: [PATCH 59/70] refactor(nyz): split dist ddp demo implementation --- ding/framework/middleware/ckpt_handler.py | 2 +- dizoo/atari/example/atari_dqn_dist_ddp.py | 136 ++++++++++++---------- 2 files changed, 78 insertions(+), 60 deletions(-) diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index d7b20f62dc..dc668a1f21 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -57,7 +57,7 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: # best eval reward so far if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value: - save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) + save_file("{}/eval.pth.tar".format(self.prefix), self.policy.eval_mode.state_dict()) self.max_eval_value = ctx.eval_value # finish diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py index 0ca678a4d2..6b615abb21 100644 --- a/dizoo/atari/example/atari_dqn_dist_ddp.py +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -15,75 +15,93 @@ from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config -def main(): - logging.getLogger().setLevel(logging.INFO) - main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' +logging.getLogger().setLevel(logging.INFO) +main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' + + +def learner(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model, enable_field=['learn']) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + with task.start(async_mode=False, ctx=OnlineRLContext()): assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
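+ # The learner pipeline runs inside DistContext so that several learner processes can
+ # train as one DDP group; only rank 0 writes checkpoints.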
+ logging.info("Learner running on node {}".format(task.router.node_id)) - set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - - model = DQN(**cfg.policy.model) - policy = DQNPolicy(cfg.policy, model=model) - - if 'learner' in task.router.labels: - from ding.utils import DistContext, get_rank - with DistContext(): - rank = get_rank() - logging.info("Learner running on node {}".format(task.router.node_id)) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - task.use( - context_exchanger( - send_keys=["train_iter"], - recv_keys=["trajectories", "episodes", "env_step", "env_episode"], - skip_n_iter=0 - ) - ) - task.use(model_exchanger(model, is_learner=True)) - task.use(nstep_reward_enhancer(cfg)) - task.use(data_pusher(cfg, buffer_)) - task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - if rank == 0: - task.use(CkptSaver(cfg, policy, train_freq=1000)) - - elif 'evaluator' in task.router.labels: - logging.info("Evaluator running on node {}".format(task.router.node_id)) - evaluator_cfg = deepcopy(cfg.env) - evaluator_cfg.is_train = False - evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager - ) - task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) - task.use(model_exchanger(model, is_learner=False)) - task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) - task.use(CkptSaver(cfg, policy, save_finish=False)) - task.use(online_logger(record_train_iter=True)) - - elif 'collector' in task.router.labels: - logging.info("Collector running on node {}".format(task.router.node_id)) - collector_cfg = deepcopy(cfg.env) - collector_cfg.is_train = True - collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager - ) + from ding.utils import DistContext, get_rank + with DistContext(): + rank = get_rank() task.use( context_exchanger( - send_keys=["trajectories", "episodes", "env_step", "env_episode"], - recv_keys=["train_iter"], - skip_n_iter=1 + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 ) ) - task.use(model_exchanger(model, is_learner=False)) - task.use(eps_greedy_handler(cfg)) - task.use(StepCollector(cfg, policy.collect_mode, collector_env)) - task.use(termination_checker(max_env_step=int(1e7))) - else: - raise KeyError("invalid router labels: {}".format(task.router.labels)) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.run() + + +def collector(): + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model, enable_field=['collect']) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
+ logging.info("Collector running on node {}".format(task.router.node_id)) + + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) task.run() -if __name__ == "__main__": - main() +def evaluator(): + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model, enable_field=['eval']) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + logging.info("Evaluator running on node {}".format(task.router.node_id)) + + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + task.run() From c350baee886a5fddb00922663356d2e6b9f83c5b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 19 Jul 2022 18:24:41 +0800 Subject: [PATCH 60/70] feature(nyz): add rdma test demo(ci skip) --- dizoo/atari/example/atari_dqn_dist_rdma.py | 72 ++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 dizoo/atari/example/atari_dqn_dist_rdma.py diff --git a/dizoo/atari/example/atari_dqn_dist_rdma.py b/dizoo/atari/example/atari_dqn_dist_rdma.py new file mode 100644 index 0000000000..71fb1d64a1 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist_rdma.py @@ -0,0 +1,72 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_dist_rdma' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
+ + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() From 0206137c66c0262e51dc891b91d197c538d6f60b Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Fri, 9 Sep 2022 00:22:29 +0800 Subject: [PATCH 61/70] feature(xjx): new style dist version, add storage loader and model loader (#425) * Add singleton log writer * Use get_instance on writer * feature(nyz): polish atari ddp demo and add dist demo * Refactor dist version * Wrap class based middleware * Change if condition in wrapper * Only run enhancer on learner * Support new parallel mode on slurm cluster * Temp data loader * Stash commit * Init data serializer * Update dump part of code * Test StorageLoader * Turn data serializer into storage loader, add storage loader in context exchanger * Add local id and startup interval * Fix storage loader * Support treetensor * Add role on event name in context exchanger, use share_memory function on tensor * Double size buffer * Copy tensor to cpu, skip wait for context on collector and evaluator * Remove data loader middleware * Upgrade k8s parser * Add epoch timer * Dont use lb * Change tensor to numpy * Remove files when stop storage loader * Discard shared object * Ensure correct load shm memory * Add model loader * Rename model_exchanger to ModelExchanger * Add model loader benchmark * Shutdown loaders when task finish * Upgrade supervisor * Dont cleanup files when shutting down * Fix async cleanup in model loader * Check model loader on dqn * Dont use loader in dqn example * Fix style check * Fix dp * Fix github tests * Skip github ci * Fix bug in event loop * Fix enhancer tests, move router from start to __init__ * Change default ttl * Add comments Co-authored-by: niuyazhe --- ding/data/__init__.py | 4 + ..._benchmark.py => test_buffer_benchmark.py} | 0 ding/data/model_loader.py | 155 +++++++++ ding/data/shm_buffer.py | 133 ++++++++ ding/data/storage/__init__.py | 2 + ding/data/storage/file.py | 25 ++ ding/data/storage/storage.py | 16 + ding/data/storage/tests/test_storage.py | 18 ++ ding/data/storage_loader.py 
| 305 ++++++++++++++++++ ding/data/tests/test_model_loader.py | 74 +++++ ding/data/tests/test_shm_buffer.py | 20 ++ ding/data/tests/test_storage_loader.py | 176 ++++++++++ ding/entry/cli_ditask.py | 11 +- ding/entry/cli_parsers/k8s_parser.py | 19 +- ding/entry/cli_parsers/slurm_parser.py | 127 +++++--- .../cli_parsers/tests/test_slurm_parser.py | 28 +- ding/entry/tests/test_cli_ditask.py | 3 +- ding/envs/env_manager/env_supervisor.py | 13 +- .../env_manager/subprocess_env_manager.py | 122 +------ ding/example/__init__.py | 0 ding/example/dqn.py | 57 +++- ding/example/dqn_dist.py | 103 ------ ding/framework/__init__.py | 2 +- ding/framework/event_loop.py | 8 +- ding/framework/middleware/__init__.py | 1 + ding/framework/middleware/ckpt_handler.py | 5 + ding/framework/middleware/collector.py | 7 +- ding/framework/middleware/distributer.py | 271 ++++++++++++++++ .../middleware/functional/__init__.py | 3 +- .../functional/advantage_estimator.py | 4 +- .../middleware/functional/data_processor.py | 3 + .../middleware/functional/enhancer.py | 9 +- .../middleware/functional/evaluator.py | 4 +- .../middleware/functional/exchanger.py | 77 ----- .../middleware/functional/explorer.py | 5 +- ding/framework/middleware/functional/timer.py | 35 ++ ding/framework/middleware/learner.py | 5 + .../middleware/tests/test_ckpt_handler.py | 10 +- .../middleware/tests/test_distributer.py | 223 +++++++++++++ .../middleware/tests/test_enhancer.py | 3 +- ding/framework/parallel.py | 26 +- ding/framework/supervisor.py | 143 ++++---- ding/framework/task.py | 49 ++- ding/framework/tests/test_event_loop.py | 2 + ding/framework/tests/test_parallel.py | 17 +- ding/framework/tests/test_supervisor.py | 75 ++++- ding/framework/tests/test_task.py | 15 +- ding/framework/wrapper/step_timer.py | 16 +- 48 files changed, 1946 insertions(+), 483 deletions(-) rename ding/data/buffer/tests/{test_benchmark.py => test_buffer_benchmark.py} (100%) create mode 100644 ding/data/model_loader.py create mode 100644 ding/data/shm_buffer.py create mode 100644 ding/data/storage/__init__.py create mode 100644 ding/data/storage/file.py create mode 100644 ding/data/storage/storage.py create mode 100644 ding/data/storage/tests/test_storage.py create mode 100644 ding/data/storage_loader.py create mode 100644 ding/data/tests/test_model_loader.py create mode 100644 ding/data/tests/test_shm_buffer.py create mode 100644 ding/data/tests/test_storage_loader.py create mode 100644 ding/example/__init__.py delete mode 100644 ding/example/dqn_dist.py create mode 100644 ding/framework/middleware/distributer.py delete mode 100644 ding/framework/middleware/functional/exchanger.py create mode 100644 ding/framework/middleware/functional/timer.py create mode 100644 ding/framework/middleware/tests/test_distributer.py diff --git a/ding/data/__init__.py b/ding/data/__init__.py index 79ac868c86..b72987cac9 100644 --- a/ding/data/__init__.py +++ b/ding/data/__init__.py @@ -1,3 +1,7 @@ from torch.utils.data import Dataset, DataLoader from ding.utils.data import create_dataset, offline_data_save_type # for compatibility from .buffer import * +from .storage import * +from .storage_loader import StorageLoader, FileStorageLoader +from .shm_buffer import ShmBufferContainer, ShmBuffer +from .model_loader import ModelLoader, FileModelLoader diff --git a/ding/data/buffer/tests/test_benchmark.py b/ding/data/buffer/tests/test_buffer_benchmark.py similarity index 100% rename from ding/data/buffer/tests/test_benchmark.py rename to ding/data/buffer/tests/test_buffer_benchmark.py diff 
--git a/ding/data/model_loader.py b/ding/data/model_loader.py new file mode 100644 index 0000000000..cd3182897b --- /dev/null +++ b/ding/data/model_loader.py @@ -0,0 +1,155 @@ +from abc import ABC, abstractmethod +import logging +from os import path +import os +from threading import Thread +from time import sleep, time +from typing import Callable, Optional +import uuid +import torch.multiprocessing as mp + +import torch +from ding.data.storage.file import FileModelStorage +from ding.data.storage.storage import Storage +from ding.framework import Supervisor +from ding.framework.supervisor import ChildType, SendPayload + + +class ModelWorker(): + + def __init__(self, model: torch.nn.Module) -> None: + self._model = model + + def save(self, storage: Storage) -> Storage: + storage.save(self._model.state_dict()) + return storage + + +class ModelLoader(Supervisor, ABC): + + def __init__(self, model: torch.nn.Module) -> None: + """ + Overview: + Save and send models asynchronously and load them synchronously. + Arguments: + - model (:obj:`torch.nn.Module`): Torch module. + """ + if next(model.parameters()).is_cuda: + super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn")) + else: + super().__init__(type_=ChildType.PROCESS) + self._model = model + self._send_callback_loop = None + self._send_callbacks = {} + self._model_worker = ModelWorker(self._model) + + def start(self): + if not self._running: + self._model.share_memory() + self.register(self._model_worker) + self.start_link() + self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True) + self._send_callback_loop.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._send_callback_loop = None + self._send_callbacks = {} + + def _loop_send_callback(self): + while True: + payload = self.recv(ignore_err=True) + if payload.err: + logging.warning("Got error when loading data: {}".format(payload.err)) + if payload.req_id in self._send_callbacks: + del self._send_callbacks[payload.req_id] + else: + if payload.req_id in self._send_callbacks: + callback = self._send_callbacks.pop(payload.req_id) + callback(payload.data) + + def load(self, storage: Storage) -> object: + """ + Overview: + Load model synchronously. + Arguments: + - storage (:obj:`Stroage`): The model should be wrapped in a storage object, e.g. FileModelStorage. + Returns: + - object (:obj:): The loaded model. + """ + return storage.load() + + @abstractmethod + def save(self, callback: Callable) -> Storage: + """ + Overview: + Save model asynchronously. + Arguments: + - callback (:obj:`Callable`): The callback function after saving model. + Returns: + - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned. + """ + raise NotImplementedError + + +class FileModelLoader(ModelLoader): + + def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None: + """ + Overview: + Model loader using files as storage media. + Arguments: + - model (:obj:`torch.nn.Module`): Torch module. + - dirname (:obj:`str`): The directory for saving files. + - ttl (:obj:`int`): Files will be automatically cleaned after ttl. 
Note that \ + files that do not time out when the process is stopped are not cleaned up \ + (to avoid errors when other processes read the file), so you may need to \ + clean up the remaining files manually + """ + super().__init__(model) + self._dirname = dirname + self._ttl = ttl + self._files = [] + self._cleanup_thread = None + + def _start_cleanup(self): + """ + Overview: + Start a cleanup thread to clean up files that are taking up too much time on the disk. + """ + if self._cleanup_thread is None: + self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True) + self._cleanup_thread.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._cleanup_thread = None + + def _loop_cleanup(self): + while True: + if len(self._files) == 0 or time() - self._files[0][0] < self._ttl: + sleep(1) + continue + _, file_path = self._files.pop(0) + if path.exists(file_path): + os.remove(file_path) + + def save(self, callback: Callable) -> FileModelStorage: + if not self._running: + logging.warning("Please start model loader before saving model.") + return + if not path.exists(self._dirname): + os.mkdir(self._dirname) + file_path = "model_{}.pth.tar".format(uuid.uuid1()) + file_path = path.join(self._dirname, file_path) + model_storage = FileModelStorage(file_path) + payload = SendPayload(proc_id=0, method="save", args=[model_storage]) + self.send(payload) + + def clean_callback(storage: Storage): + self._files.append([time(), file_path]) + callback(storage) + + self._send_callbacks[payload.req_id] = clean_callback + self._start_cleanup() + return model_storage diff --git a/ding/data/shm_buffer.py b/ding/data/shm_buffer.py new file mode 100644 index 0000000000..b76f5d56e9 --- /dev/null +++ b/ding/data/shm_buffer.py @@ -0,0 +1,133 @@ +from typing import Any, Optional, Union, Tuple, Dict +from multiprocessing import Array +import ctypes +import numpy as np +import torch + +_NTYPE_TO_CTYPE = { + np.bool_: ctypes.c_bool, + np.uint8: ctypes.c_uint8, + np.uint16: ctypes.c_uint16, + np.uint32: ctypes.c_uint32, + np.uint64: ctypes.c_uint64, + np.int8: ctypes.c_int8, + np.int16: ctypes.c_int16, + np.int32: ctypes.c_int32, + np.int64: ctypes.c_int64, + np.float32: ctypes.c_float, + np.float64: ctypes.c_double, +} + + +class ShmBuffer(): + """ + Overview: + Shared memory buffer to store numpy array. + """ + + def __init__( + self, + dtype: Union[type, np.dtype], + shape: Tuple[int], + copy_on_get: bool = True, + ctype: Optional[type] = None + ) -> None: + """ + Overview: + Initialize the buffer. + Arguments: + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. + - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor. + """ + if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype + dtype = dtype.type + self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape))) + self.dtype = dtype + self.shape = shape + self.copy_on_get = copy_on_get + self.ctype = ctype + + def fill(self, src_arr: np.ndarray) -> None: + """ + Overview: + Fill the shared memory buffer with a numpy array. (Replace the original one.) + Arguments: + - src_arr (:obj:`np.ndarray`): array to fill the buffer. 
+ """ + assert isinstance(src_arr, np.ndarray), type(src_arr) + # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten + # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten + # so we reshape dst_arr rather than flatten src_arr + dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) + np.copyto(dst_arr, src_arr) + + def get(self) -> np.ndarray: + """ + Overview: + Get the array stored in the buffer. + Return: + - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. + """ + data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) + if self.copy_on_get: + data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory + if self.ctype is torch.Tensor: + data = torch.from_numpy(data) + return data + + +class ShmBufferContainer(object): + """ + Overview: + Support multiple shared memory buffers. Each key-value is name-buffer. + """ + + def __init__( + self, + dtype: Union[Dict[Any, type], type, np.dtype], + shape: Union[Dict[Any, tuple], tuple], + copy_on_get: bool = True + ) -> None: + """ + Overview: + Initialize the buffer container. + Arguments: + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. + - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ + multiple buffers; If `tuple`, use single buffer. + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + """ + if isinstance(shape, dict): + self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} + elif isinstance(shape, (tuple, list)): + self._data = ShmBuffer(dtype, shape, copy_on_get) + else: + raise RuntimeError("not support shape: {}".format(shape)) + self._shape = shape + + def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: + """ + Overview: + Fill the one or many shared memory buffer. + Arguments: + - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. + """ + if isinstance(self._shape, dict): + for k in self._shape.keys(): + self._data[k].fill(src_arr[k]) + elif isinstance(self._shape, (tuple, list)): + self._data.fill(src_arr) + + def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: + """ + Overview: + Get the one or many arrays stored in the buffer. + Return: + - data (:obj:`np.ndarray`): The array(s) stored in the buffer. 
+ """ + if isinstance(self._shape, dict): + return {k: self._data[k].get() for k in self._shape.keys()} + elif isinstance(self._shape, (tuple, list)): + return self._data.get() diff --git a/ding/data/storage/__init__.py b/ding/data/storage/__init__.py new file mode 100644 index 0000000000..962fbbbf18 --- /dev/null +++ b/ding/data/storage/__init__.py @@ -0,0 +1,2 @@ +from .storage import Storage +from .file import FileStorage, FileModelStorage diff --git a/ding/data/storage/file.py b/ding/data/storage/file.py new file mode 100644 index 0000000000..e6a89910b8 --- /dev/null +++ b/ding/data/storage/file.py @@ -0,0 +1,25 @@ +from typing import Any +from ding.data.storage import Storage +import pickle + +from ding.utils.file_helper import read_file, save_file + + +class FileStorage(Storage): + + def save(self, data: Any) -> None: + with open(self.path, "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + + def load(self) -> Any: + with open(self.path, "rb") as f: + return pickle.load(f) + + +class FileModelStorage(Storage): + + def save(self, state_dict: object) -> None: + save_file(self.path, state_dict) + + def load(self) -> object: + return read_file(self.path) diff --git a/ding/data/storage/storage.py b/ding/data/storage/storage.py new file mode 100644 index 0000000000..e6a0dae679 --- /dev/null +++ b/ding/data/storage/storage.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class Storage(ABC): + + def __init__(self, path: str) -> None: + self.path = path + + @abstractmethod + def save(self, data: Any) -> None: + raise NotImplementedError + + @abstractmethod + def load(self) -> Any: + raise NotImplementedError diff --git a/ding/data/storage/tests/test_storage.py b/ding/data/storage/tests/test_storage.py new file mode 100644 index 0000000000..8f6f1d2c47 --- /dev/null +++ b/ding/data/storage/tests/test_storage.py @@ -0,0 +1,18 @@ +import tempfile +import pytest +import os +from os import path +from ding.data.storage import FileStorage + + +@pytest.mark.unittest +def test_file_storage(): + path_ = path.join(tempfile.gettempdir(), "test_storage.txt") + try: + storage = FileStorage(path=path_) + storage.save("test") + content = storage.load() + assert content == "test" + finally: + if path.exists(path_): + os.remove(path_) diff --git a/ding/data/storage_loader.py b/ding/data/storage_loader.py new file mode 100644 index 0000000000..daf18e2d82 --- /dev/null +++ b/ding/data/storage_loader.py @@ -0,0 +1,305 @@ +from dataclasses import dataclass +import os +import torch +import numpy as np +import uuid +import treetensor.torch as ttorch +from abc import ABC, abstractmethod +from ditk import logging +from time import sleep, time +from threading import Lock, Thread +from typing import Any, Callable, Dict, List, Optional, Union +from ding.data import FileStorage, Storage +from os import path +from ding.data.shm_buffer import ShmBuffer +from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload + + +@dataclass +class ShmObject: + id_: ShmBuffer + buf: Any + + +class StorageWorker: + + def load(self, storage: Storage) -> Any: + return storage.load() + + +class StorageLoader(Supervisor, ABC): + + def __init__(self, worker_num: int = 3) -> None: + """ + Overview: + Save and send data synchronously and load them asynchronously. + Arguments: + - worker_num (:obj:`int`): Subprocess worker number. + """ + super().__init__(type_=ChildType.PROCESS) + self._load_lock = Lock() # Load (first meet) should be called one by one. 
+ self._callback_map: Dict[str, Callable] = {} + self._shm_obj_map: Dict[int, ShmObject] = {} + self._worker_num = worker_num + self._req_count = 0 + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._recv_loop = None + self._callback_map = {} + self._shm_obj_map = {} + self._req_count = 0 + + def start_link(self) -> None: + if not self._running: + super().start_link() + self._recv_loop = Thread(target=self._loop_recv, daemon=True) + self._recv_loop.start() + + @property + def _next_proc_id(self): + return self._req_count % self._worker_num + + @abstractmethod + def save(self, obj: Union[Dict, List]) -> Storage: + """ + Overview: + Save data with a storage object synchronously. + Arguments: + - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor. + Returns: + - storage (:obj:`Storage`): The storage object. + """ + raise NotImplementedError + + def load(self, storage: Storage, callback: Callable): + """ + Overview: + Load data from a storage object asynchronously. \ + This function will analysis the data structure when first meet a new data, \ + then alloc a shared memory buffer for each subprocess, these shared memory buffer \ + will be responsible for asynchronously loading data into memory. + Arguments: + - storage (:obj:`Storage`): The storage object. + - callback (:obj:`Callable`): Callback function after data loaded. + """ + with self._load_lock: + if not self._running: + self._first_meet(storage, callback) + return + + payload = SendPayload(proc_id=self._next_proc_id, method="load", args=[storage]) + self._callback_map[payload.req_id] = callback + self.send(payload) + self._req_count += 1 + + def _first_meet(self, storage: Storage, callback: Callable): + """ + Overview: + When first meet an object type, we'll load this object directly and analysis the structure, + to allocate the shared memory object and create subprocess workers. + Arguments: + - storage (:obj:`Storage`): The storage object. + - callback (:obj:`Callable`): Callback function after data loaded. + """ + obj = storage.load() + # Create three workers for each usage type. + for i in range(self._worker_num): + shm_obj = self._create_shm_buffer(obj) + self._shm_obj_map[i] = shm_obj + self.register(StorageWorker, shm_buffer=shm_obj, shm_callback=self._shm_callback) + self.start_link() + callback(obj) + + def _loop_recv(self): + while True: + payload = self.recv(ignore_err=True) + if payload.err: + logging.warning("Got error when loading data: {}".format(payload.err)) + if payload.req_id in self._callback_map: + del self._callback_map[payload.req_id] + else: + self._shm_putback(payload, self._shm_obj_map[payload.proc_id]) + if payload.req_id in self._callback_map: + callback = self._callback_map.pop(payload.req_id) + callback(payload.data) + + def _create_shm_buffer(self, obj: Union[Dict, List]) -> Optional[ShmObject]: + """ + Overview: + Create shared object (buf and callback) by walk through the data structure. + Arguments: + - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor. + Returns: + - shm_buf (:obj:`Optional[ShmObject]`): The shared memory buffer. 
+ """ + max_level = 2 + + def to_shm(obj: Dict, level: int): + if level > max_level: + return + shm_buf = None + if isinstance(obj, Dict) or isinstance(obj, ttorch.Tensor): + shm_buf = {} + for key, val in obj.items(): + # Only numpy array can fill into shm buffer + if isinstance(val, np.ndarray): + shm_buf[key] = ShmBuffer(val.dtype, val.shape, copy_on_get=False) + elif isinstance(val, torch.Tensor): + shm_buf[key] = ShmBuffer( + val.numpy().dtype, val.numpy().shape, copy_on_get=False, ctype=torch.Tensor + ) + # Recursive parsing structure + elif isinstance(val, Dict) or isinstance(val, ttorch.Tensor) or isinstance(val, List): + buf = to_shm(val, level=level + 1) + if buf: + shm_buf[key] = buf + elif isinstance(obj, List): + # Double the size of buffer + shm_buf = [to_shm(o, level=level) for o in obj] * 2 + if all(s is None for s in shm_buf): + shm_buf = [] + return shm_buf + + shm_buf = to_shm(obj, level=0) + if shm_buf is not None: + random_id = self._random_id() + shm_buf = ShmObject(id_=ShmBuffer(random_id.dtype, random_id.shape, copy_on_get=False), buf=shm_buf) + return shm_buf + + def _random_id(self) -> np.ndarray: + return np.random.randint(1, 9e6, size=(1)) + + def _shm_callback(self, payload: RecvPayload, shm_obj: ShmObject): + """ + Overview: + Called in subprocess, put payload.data into buf. + Arguments: + - payload (:obj:`RecvPayload`): The recv payload with meta info of the data. + - shm_obj (:obj:`ShmObject`): The shm buffer. + """ + assert isinstance(payload.data, type( + shm_obj.buf + )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf)) + + # Sleep while shm object is not ready. + while shm_obj.id_.get()[0] != 0: + sleep(0.001) + + max_level = 2 + + def shm_callback(data: Union[Dict, List, ttorch.Tensor], buf: Union[Dict, List], level: int): + if level > max_level: + return + + if isinstance(buf, List): + assert isinstance(data, List), "Data ({}) and buf ({}) type not match".format(type(data), type(buf)) + elif isinstance(buf, Dict): + assert isinstance(data, ttorch.Tensor) or isinstance( + data, Dict + ), "Data ({}) and buf ({}) type not match".format(type(data), type(buf)) + + if isinstance(data, Dict) or isinstance(data, ttorch.Tensor): + for key, val in data.items(): + if isinstance(val, torch.Tensor): + val = val.numpy() + buf_val = buf.get(key) + if buf_val is None: + continue + if isinstance(buf_val, ShmBuffer) and isinstance(val, np.ndarray): + buf_val.fill(val) + data[key] = None + else: + shm_callback(val, buf_val, level=level + 1) + elif isinstance(data, List): + for i, data_ in enumerate(data): + shm_callback(data_, buf[i], level=level) + + shm_callback(payload.data, buf=shm_obj.buf, level=0) + id_ = self._random_id() + shm_obj.id_.fill(id_) + payload.extra = id_ + + def _shm_putback(self, payload: RecvPayload, shm_obj: ShmObject): + """ + Overview: + Called in main process, put buf back into payload.data. + Arguments: + - payload (:obj:`RecvPayload`): The recv payload with meta info of the data. + - shm_obj (:obj:`ShmObject`): The shm buffer. 
+ """ + assert isinstance(payload.data, type( + shm_obj.buf + )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf)) + + assert shm_obj.id_.get()[0] == payload.extra[0], "Shm object and payload do not match ({} - {}).".format( + shm_obj.id_.get()[0], payload.extra[0] + ) + + def shm_putback(data: Union[Dict, List], buf: Union[Dict, List]): + if isinstance(data, Dict) or isinstance(data, ttorch.Tensor): + for key, val in data.items(): + buf_val = buf.get(key) + if buf_val is None: + continue + if val is None and isinstance(buf_val, ShmBuffer): + data[key] = buf[key].get() + else: + shm_putback(val, buf_val) + elif isinstance(data, List): + for i, data_ in enumerate(data): + shm_putback(data_, buf[i]) + + shm_putback(payload.data, buf=shm_obj.buf) + shm_obj.id_.fill(np.array([0])) + + +class FileStorageLoader(StorageLoader): + + def __init__(self, dirname: str, ttl: int = 20, worker_num: int = 3) -> None: + """ + Overview: + Dump and load object with file storage. + Arguments: + - dirname (:obj:`str`): The directory to save files. + - ttl (:obj:`str`): Maximum time to keep a file, after which it will be deleted. + - worker_num (:obj:`int`): Number of subprocess worker loaders. + """ + super().__init__(worker_num) + self._dirname = dirname + self._files = [] + self._cleanup_thread = None + self._ttl = ttl # # Delete files created 10 minutes ago. + + def save(self, obj: Union[Dict, List]) -> FileStorage: + if not path.exists(self._dirname): + os.mkdir(self._dirname) + filename = "{}.pkl".format(uuid.uuid1()) + full_path = path.join(self._dirname, filename) + f = FileStorage(full_path) + f.save(obj) + self._files.append([time(), f.path]) + self._start_cleanup() + return f + + def _start_cleanup(self): + """ + Overview: + Start a cleanup thread to clean up files that are taking up too much time on the disk. 
+ """ + if self._cleanup_thread is None: + self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True) + self._cleanup_thread.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._cleanup_thread = None + + def _loop_cleanup(self): + while True: + if len(self._files) == 0 or time() - self._files[0][0] < self._ttl: + sleep(1) + continue + _, file_path = self._files.pop(0) + if path.exists(file_path): + os.remove(file_path) diff --git a/ding/data/tests/test_model_loader.py b/ding/data/tests/test_model_loader.py new file mode 100644 index 0000000000..caf8c07186 --- /dev/null +++ b/ding/data/tests/test_model_loader.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from time import sleep, time +import pytest +from ding.data.model_loader import FileModelLoader +from ding.data.storage.file import FileModelStorage +from ding.model import DQN +from ding.config import compile_config +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config +from os import path +import torch + + +@pytest.mark.tmp # gitlab ci and local test pass, github always fail +def test_model_loader(): + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + model = DQN(**cfg.policy.model) + loader = FileModelLoader(model=model, dirname=tempdir, ttl=1) + try: + loader.start() + model_storage = None + + def save_model(storage): + nonlocal model_storage + model_storage = storage + + start = time() + loader.save(save_model) + save_time = time() - start + print("Save time: {:.4f}s".format(save_time)) + assert save_time < 0.1 + sleep(0.5) + assert isinstance(model_storage, FileModelStorage) + assert len(loader._files) > 0 + + state_dict = loader.load(model_storage) + model.load_state_dict(state_dict) + + sleep(2) + assert not path.exists(model_storage.path) + assert len(loader._files) == 0 + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.benchmark +def test_model_loader_benchmark(): + model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 100)) # 40MB + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + loader = FileModelLoader(model=model, dirname=tempdir) + + try: + loader.start() + count = 0 + + def send_callback(_): + nonlocal count + count += 1 + + start = time() + for _ in range(5): + loader.save(send_callback) + sleep(0.2) + + while count < 5: + sleep(0.001) + + assert time() - start < 1.2 + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + loader.shutdown() diff --git a/ding/data/tests/test_shm_buffer.py b/ding/data/tests/test_shm_buffer.py new file mode 100644 index 0000000000..04334b4799 --- /dev/null +++ b/ding/data/tests/test_shm_buffer.py @@ -0,0 +1,20 @@ +import pytest +import numpy as np +import timeit +from ding.data.shm_buffer import ShmBuffer +import multiprocessing as mp + + +def subprocess(shm_buf): + data = np.random.rand(1024, 1024).astype(np.float32) + res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000) + print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) + + +@pytest.mark.benchmark +def test_shm_buffer(): + data = np.random.rand(1024, 1024).astype(np.float32) + shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False) + proc = mp.Process(target=subprocess, args=[shm_buf]) + proc.start() + proc.join() diff --git a/ding/data/tests/test_storage_loader.py 
b/ding/data/tests/test_storage_loader.py new file mode 100644 index 0000000000..5ab07acd73 --- /dev/null +++ b/ding/data/tests/test_storage_loader.py @@ -0,0 +1,176 @@ +import os +import timeit +import pytest +import tempfile +import shutil +import numpy as np +import torch +import treetensor.torch as ttorch +from ding.data.shm_buffer import ShmBuffer +from ding.data.storage_loader import FileStorageLoader +from time import sleep, time +from os import path +from ding.framework.supervisor import RecvPayload + + +@pytest.mark.tmp # gitlab ci and local test pass, github always fail +def test_file_storage_loader(): + tempdir = path.join(tempfile.gettempdir(), "test_storage_loader") + loader = FileStorageLoader(dirname=tempdir) + try: + total_num = 200 + storages = [] + for i in range(10): + # 21MB + data = [ + { + "s": "abc", + "obs": np.random.rand(4, 84, 84).astype(np.float32), + # "next_obs": np.random.rand(4, 84, 84).astype(np.float32), + # "obs": torch.rand(4, 84, 84, dtype=torch.float32), + "next_obs": torch.rand(4, 84, 84, dtype=torch.float32) + } for _ in range(96) + ] + storage = loader.save(data) + storages.append(storage) + + start = time() + for i in range(total_num): + storage = storages[i % 10] + data = storage.load() + origin_time_cost = time() - start + print("Load time cost: {:.4f}s".format(origin_time_cost)) + + call_times = 0 + + def callback(data): + assert data[0]['obs'] is not None + nonlocal call_times + call_times += 1 + + # First initialize shared memory is very slow, discard this time cost. + start = time() + loader._first_meet(storage=storages[0], callback=callback) + print("Initialize shared memory time: {:.4f}s".format(time() - start)) + + start = time() + for i in range(1, total_num): + storage = storages[i % 10] + loader.load(storage, callback) + + while True: + if call_times == total_num: + break + sleep(0.01) + new_time_cost = time() - start + print("Loader time cost: {:.4f}s".format(new_time_cost)) + + assert new_time_cost < origin_time_cost + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + loader.shutdown() + + +@pytest.mark.unittest +def test_file_storage_loader_cleanup(): + tempdir = path.join(tempfile.gettempdir(), "test_storage_loader") + loader = FileStorageLoader(dirname=tempdir, ttl=1) + try: + storages = [] + for _ in range(4): + data = np.random.rand(4, 84, 84).astype(np.float32) + storage = loader.save(data) + storages.append(storage) + sleep(0.5) + assert len(os.listdir(tempdir)) < 4 + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + loader.shutdown() + + +@pytest.mark.unittest +def test_shared_object(): + loader = FileStorageLoader(dirname="") + + # ========== Test array ========== + obj = [{"obs": np.random.rand(100, 100)} for _ in range(10)] + shm_obj = loader._create_shm_buffer(obj) + assert len(shm_obj.buf) == len(obj) * 2 + assert isinstance(shm_obj.buf[0]["obs"], ShmBuffer) + + # Callback + payload = RecvPayload(proc_id=0, data=obj) + loader._shm_callback(payload=payload, shm_obj=shm_obj) + assert len(payload.data) == 10 + assert [d["obs"] is None for d in payload.data] + + # ========== Putback ========== + loader._shm_putback(payload=payload, shm_obj=shm_obj) + obj = payload.data + assert len(obj) == 10 + for o in obj: + assert isinstance(o["obs"], np.ndarray) + assert o["obs"].shape == (100, 100) + + # ========== Test dict ========== + obj = {"obs": torch.rand(100, 100, dtype=torch.float32)} + shm_obj = loader._create_shm_buffer(obj) + assert isinstance(shm_obj.buf["obs"], ShmBuffer) + + payload = 
RecvPayload(proc_id=0, data=obj) + loader._shm_callback(payload=payload, shm_obj=shm_obj) + assert payload.data["obs"] is None + + loader._shm_putback(payload=payload, shm_obj=shm_obj) + assert isinstance(payload.data["obs"], torch.Tensor) + assert payload.data["obs"].shape == (100, 100) + + # ========== Test treetensor ========== + obj = {"trajectories": [ttorch.as_tensor({"obs": torch.rand(10, 10, dtype=torch.float32)}) for _ in range(10)]} + shm_obj = loader._create_shm_buffer(obj) + + payload = RecvPayload(proc_id=0, data=obj) + loader._shm_callback(payload=payload, shm_obj=shm_obj) + assert len(payload.data["trajectories"]) == 10 + for traj in payload.data["trajectories"]: + assert traj["obs"] is None + + loader._shm_putback(payload=payload, shm_obj=shm_obj) + for traj in payload.data["trajectories"]: + assert isinstance(traj["obs"], torch.Tensor) + assert traj["obs"].shape == (10, 10) + + +@pytest.mark.benchmark +def test_shared_object_benchmark(): + loader = FileStorageLoader(dirname="") + # ========== Test treetensor ========== + obj = { + "env_step": 0, + "trajectories": [ + ttorch.as_tensor( + { + "done": False, + "reward": torch.tensor([1, 0], dtype=torch.int32), + "obs": torch.rand(4, 84, 84, dtype=torch.float32), + "next_obs": torch.rand(4, 84, 84, dtype=torch.float32), + "action": torch.tensor([1], dtype=torch.int32), + "collect_train_iter": torch.tensor([1], dtype=torch.int32), + "env_data_id": torch.tensor([1], dtype=torch.int32), + } + ) for _ in range(10) + ] + } + buf = loader._create_shm_buffer(obj) + payload = RecvPayload(proc_id=0, data=obj) + loader._shm_callback(payload=payload, shm_obj=buf) + + def stmt(): + payload.extra = buf.id_.get() + loader._shm_putback(payload=payload, shm_obj=buf) + + res = timeit.repeat(stmt, repeat=5, number=1000) + print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) + assert np.mean(res) < 1 diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py index f05bd5257a..443fe1a6b6 100644 --- a/ding/entry/cli_ditask.py +++ b/ding/entry/cli_ditask.py @@ -43,8 +43,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: @click.option( "--ports", type=str, - default="50515", - help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515" + help="The port addresses that the tasks listen to, e.g. 
50515,50516, default: k8s, local: 50515, slurm: 15151" ) @click.option("--attach-to", type=str, help="The addresses to connect to.") @click.option("--address", type=str, help="The address to listen to (without port).") @@ -62,6 +61,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None: @click.option("--redis-host", type=str, help="Redis host.") @click.option("--redis-port", type=int, help="Redis port.") @click.option("-m", "--main", type=str, help="Main function of entry module.") +@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.") @click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP") def cli_ditask(*args, **kwargs): return _cli_ditask(*args, **kwargs) @@ -87,7 +87,7 @@ def _parse_platform_args(platform: str, platform_spec: str, all_args: dict): parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args) except Exception as e: click.echo("error when parse platform spec configure: {}".format(e)) - exit(1) + raise e return parsed_args @@ -106,6 +106,7 @@ def _cli_ditask( mq_type: str, redis_host: str, redis_port: int, + startup_interval: int, local_rank: int = 0, platform: str = None, platform_spec: str = None, @@ -130,6 +131,7 @@ def _cli_ditask( mod = importlib.import_module(mod_name) main_func = getattr(mod, func_name) # Parse arguments + ports = ports or 50515 if not isinstance(ports, int): ports = ports.split(",") ports = list(map(lambda i: int(i), ports)) @@ -154,5 +156,6 @@ def _cli_ditask( node_ids=node_ids, mq_type=mq_type, redis_host=redis_host, - redis_port=redis_port + redis_port=redis_port, + startup_interval=startup_interval )(main_func) diff --git a/ding/entry/cli_parsers/k8s_parser.py b/ding/entry/cli_parsers/k8s_parser.py index 3767ef2879..6f2b0aebe7 100644 --- a/ding/entry/cli_parsers/k8s_parser.py +++ b/ding/entry/cli_parsers/k8s_parser.py @@ -1,11 +1,12 @@ import os import numpy as np -from typing import List, Optional +from time import sleep +from typing import Dict, List, Optional class K8SParser(): - def __init__(self, platform_spec: Optional[str] = None, **kwargs) -> None: + def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None: """ Overview: Should only set global cluster properties @@ -14,9 +15,9 @@ def __init__(self, platform_spec: Optional[str] = None, **kwargs) -> None: self.nodelist = self._parse_node_list() self.ntasks = len(self.nodelist) self.platform_spec = platform_spec - self.parallel_workers = kwargs.get("parallel_workers", 1) - self.topology = kwargs.get("topology", "alone") - self.ports = kwargs.get("ports", 50515) + self.parallel_workers = kwargs.get("parallel_workers") or 1 + self.topology = kwargs.get("topology") or "alone" + self.ports = int(kwargs.get("ports") or 50515) self.tasks = {} def parse(self) -> dict: @@ -49,13 +50,13 @@ def _get_task(self, procid: int) -> dict: else: task = {} if "ports" not in task: - task["ports"] = self._get_ports() + task["ports"] = self.kwargs.get("ports") or self._get_ports() if "address" not in task: - task["address"] = self._get_address(procid) + task["address"] = self.kwargs.get("address") or self._get_address(procid) if "node_ids" not in task: - task["node_ids"] = self._get_node_id(procid) + task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid) - task["attach_to"] = self._get_attach_to(procid, task.get("attach_to")) + task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to")) task["topology"] = self.topology 
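+ # The cluster-wide topology and parallel worker count are applied to every generated task.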
task["parallel_workers"] = self.parallel_workers diff --git a/ding/entry/cli_parsers/slurm_parser.py b/ding/entry/cli_parsers/slurm_parser.py index 3a335eb758..c46716438b 100644 --- a/ding/entry/cli_parsers/slurm_parser.py +++ b/ding/entry/cli_parsers/slurm_parser.py @@ -1,30 +1,55 @@ import os import re -from typing import List +from time import sleep +import numpy as np +from typing import Any, Dict, List, Optional class SlurmParser(): - def __init__(self, platform_spec: str, **kwargs) -> None: + def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None: """ Overview: Should only set global cluster properties """ self.kwargs = kwargs self.ntasks = int(os.environ["SLURM_NTASKS"]) - self.tasks = platform_spec["tasks"] + self.platform_spec = platform_spec + self.tasks = {} self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"]) self.nodelist = self._parse_node_list() + self.ports = int(kwargs.get("ports") or 15151) + self.parallel_workers = kwargs.get("parallel_workers") or 1 + self.topology = kwargs.get("topology") or "alone" def parse(self) -> dict: - assert len(self.tasks) == self.ntasks procid = int(os.environ["SLURM_PROCID"]) - nodename = os.environ["SLURMD_NODENAME"] - task = self._get_node_args(procid) + task = self._get_task(procid) # Validation - assert task["address"] == nodename + assert task["address"] == os.environ["SLURMD_NODENAME"] return {**self.kwargs, **task} + def _get_task(self, procid: int) -> Dict[str, Any]: + if procid in self.tasks: + return self.tasks.get(procid) + if self.platform_spec: + task = self.platform_spec["tasks"][procid] + else: + task = {} + if "ports" not in task: + task["ports"] = self._get_ports(procid) + if "address" not in task: + task["address"] = self._get_address(procid) + if "node_ids" not in task: + task["node_ids"] = self._get_node_id(procid) + + task["attach_to"] = self._get_attach_to(procid, task.get("attach_to")) + task["topology"] = self.topology + task["parallel_workers"] = self.parallel_workers + + self.tasks[procid] = task + return task + def _parse_node_list(self) -> List[str]: nodelist = os.environ["SLURM_NODELIST"] result = re.match(r"(.*)?\[(.*)\]$", nodelist) @@ -40,58 +65,86 @@ def _parse_node_list(self) -> List[str]: nodelist.append(prefix + tail) elif isinstance(nodelist, str): nodelist = [nodelist] + if self.ntasks_per_node > 1: + expand_nodelist = [] # Expand node for each task + for node in nodelist: + for _ in range(self.ntasks_per_node): + expand_nodelist.append(node) + nodelist = expand_nodelist return nodelist - def _get_node_args(self, procid: int) -> dict: - """ - Overview: - Complete node properties, use environment vars in list instead of on current node. - For example, if you want to set nodename in this function, please derive it from SLURM_NODELIST, - the variable from SLURMD_NODENAME should only be used in validation. 
- """ - task = self.tasks[procid] - if "attach_to" in task: - task["attach_to"] = self._get_attach_to(task["attach_to"]) - if "address" not in task: - task["address"] = self._get_address(procid) - if "ports" not in task: - task["ports"] = self._get_ports(procid) - if "node_ids" not in task: - task["node_ids"] = procid - return task + def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str: + if attach_to: + attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")] + elif procid == 0: + attach_to = [] + else: + if self.topology == "mesh": + prev_tasks = [self._get_task(i) for i in range(procid)] + attach_to = [self._get_attach_to_from_task(task) for task in prev_tasks] + attach_to = list(np.concatenate(attach_to)) + elif self.topology == "star": + head_task = self._get_task(0) + attach_to = self._get_attach_to_from_task(head_task) + else: + attach_to = [] - def _get_attach_to(self, attach_to: str) -> str: - attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")] return ",".join(attach_to) def _get_attach_to_part(self, attach_part: str) -> str: + """ + Overview: + Parse each part of attach_to. + Arguments: + - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node:0 + Returns + - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000 + """ if not attach_part.startswith("$node."): return attach_part attach_node_id = int(attach_part[6:]) - attach_node = self._get_node_args(self._get_procid_from_nodeid(attach_node_id)) - return "tcp://{}:{}".format(attach_node["address"], attach_node["ports"]) + attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id)) + return self._get_tcp_link(attach_task["address"], attach_task["ports"]) + + def _get_attach_to_from_task(self, task: dict) -> List[str]: + """ + Overview: + Get attach_to list from task, note that parallel_workers will affact the connected processes. + Arguments: + - task (:obj:`dict`): The task object. + Returns + - attach_to (:obj:`str`): The real address, e.g. 
tcp://SH-0:50000 + """ + port = task.get("ports") + address = task.get("address") + ports = [int(port) + i for i in range(self.parallel_workers)] + attach_to = [self._get_tcp_link(address, port) for port in ports] + return attach_to def _get_procid_from_nodeid(self, nodeid: int) -> int: procid = None - for i, task in enumerate(self.tasks): - if task.get("node_ids") == nodeid: - procid = i - break - elif nodeid == i: + for i in range(self.ntasks): + task = self._get_task(i) + if task["node_ids"] == nodeid: procid = i break if procid is None: raise Exception("Can not find procid from nodeid: {}".format(nodeid)) return procid - def _get_ports(self, procid: int) -> List[int]: - ports = 15151 + procid % self.ntasks_per_node - return ports + def _get_ports(self, procid) -> int: + return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers def _get_address(self, procid: int) -> str: - address = self.nodelist[procid // self.ntasks_per_node] + address = self.nodelist[procid] return address + def _get_node_id(self, procid: int) -> int: + return procid * self.parallel_workers + + def _get_tcp_link(self, address: str, port: int) -> str: + return "tcp://{}:{}".format(address, port) + def slurm_parser(platform_spec: str, **kwargs) -> dict: return SlurmParser(platform_spec, **kwargs).parse() diff --git a/ding/entry/cli_parsers/tests/test_slurm_parser.py b/ding/entry/cli_parsers/tests/test_slurm_parser.py index f56efdd663..9b817ba48a 100644 --- a/ding/entry/cli_parsers/tests/test_slurm_parser.py +++ b/ding/entry/cli_parsers/tests/test_slurm_parser.py @@ -10,12 +10,7 @@ def set_slurm_env(): os.environ["SLURM_NTASKS"] = '6' # Parameter n,Process count / Task count os.environ["SLURM_NTASKS_PER_NODE"] = '3' # Parameter ntasks-per-node,process count of each node os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-38-[190,215]' # All the nodes - os.environ["SLURM_SRUN_COMM_PORT"] = '42932' # Available ports - os.environ["SLURM_TOPOLOGY_ADDR"] = 'SH-IDC1-10-5-38-215' # Name of current node - os.environ["SLURM_NODEID"] = '1' # Node order,start from 0 os.environ["SLURM_PROCID"] = '3' # Proc order,start from 0,the read proc order may be different from nominal order - os.environ["SLURM_LOCALID"] = '0' # Proc order on current node, smaller or equal than ntasks-per-node - 1 - os.environ["SLURM_GTIDS"] = '2,3' # All the proc ids on current node os.environ["SLURMD_NODENAME"] = 'SH-IDC1-10-5-38-215' # Name of current node yield @@ -23,12 +18,7 @@ def set_slurm_env(): del os.environ["SLURM_NTASKS"] del os.environ["SLURM_NTASKS_PER_NODE"] del os.environ["SLURM_NODELIST"] - del os.environ["SLURM_SRUN_COMM_PORT"] - del os.environ["SLURM_TOPOLOGY_ADDR"] - del os.environ["SLURM_NODEID"] del os.environ["SLURM_PROCID"] - del os.environ["SLURM_LOCALID"] - del os.environ["SLURM_GTIDS"] del os.environ["SLURMD_NODENAME"] @@ -73,8 +63,22 @@ def test_slurm_parser(): "tcp://SH-IDC1-10-5-38-190:15152," +\ "tcp://SH-IDC1-10-5-38-190:15153" + # Test without platform_spec + all_args = slurm_parser(None, topology="mesh", mq_type="nng") + assert all_args["address"] == "SH-IDC1-10-5-38-215" + assert all_args["node_ids"] == 3 + assert all_args["parallel_workers"] == 1 + assert all_args[ + "attach_to" + ] == "tcp://SH-IDC1-10-5-38-190:15151," +\ + "tcp://SH-IDC1-10-5-38-190:15152," +\ + "tcp://SH-IDC1-10-5-38-190:15153" + # Test _parse_node_list sp = SlurmParser(platform_spec) os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-[38-40]' - nodelist = sp._parse_node_list() - assert nodelist == ['SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 
'SH-IDC1-10-5-40'] + nodelist = sp._parse_node_list() # Nodes * parallel_workers + assert nodelist == [ + 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-39', + 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40' + ] diff --git a/ding/entry/tests/test_cli_ditask.py b/ding/entry/tests/test_cli_ditask.py index 66cd37e906..6bb64e5e6e 100644 --- a/ding/entry/tests/test_cli_ditask.py +++ b/ding/entry/tests/test_cli_ditask.py @@ -25,7 +25,8 @@ def test_cli_ditask(): "node_ids": 0, "mq_type": "nng", "redis_host": "", - "redis_port": "" + "redis_port": "", + "startup_interval": 1 } os.environ["DI_NODES"] = '127.0.0.1' os.environ["DI_RANK"] = '0' diff --git a/ding/envs/env_manager/env_supervisor.py b/ding/envs/env_manager/env_supervisor.py index b3e1c17fcf..ec5e29beab 100644 --- a/ding/envs/env_manager/env_supervisor.py +++ b/ding/envs/env_manager/env_supervisor.py @@ -5,10 +5,10 @@ import gym from ding.framework import Supervisor from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable -from ding.framework.supervisor import ChildType, RecvPayload, SendPayload, SharedObject +from ding.framework.supervisor import ChildType, RecvPayload, SendPayload from ding.utils import make_key_as_identifier from ditk import logging -from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer +from ding.data import ShmBufferContainer import enum import treetensor.numpy as tnp import numbers @@ -106,9 +106,7 @@ def __init__( for env_id in range(len(self._env_fn)) } for env_init in env_fn: - self.register( - env_init, shared_object=SharedObject(buf=self._obs_buffers, callback=self._shm_callback) - ) + self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback) else: for env_init in env_fn: self.register(env_init) @@ -136,6 +134,11 @@ def _init_states(self): self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf}) def _shm_callback(self, payload: RecvPayload, obs_buffers: Any): + """ + Overview: + This method will be called in child worker, so we can put large data into shared memory + and replace the original payload data to none, then reduce the serialization/deserialization cost. 
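+        Arguments:
+            - payload (:obj:`RecvPayload`): The payload produced in the child worker; its ``data`` \
+                field is copied into the shared buffer and then set to ``None``.
+            - obs_buffers (:obj:`Any`): Dict of ``ShmBufferContainer`` keyed by env id (``proc_id``).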
+ """ if payload.method == "reset" and payload.data is not None: obs_buffers[payload.proc_id].fill(payload.data) payload.data = None diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py index 19498aae33..fefa053ae1 100644 --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -1,5 +1,5 @@ from typing import Any, Union, List, Tuple, Dict, Callable, Optional -from multiprocessing import Pipe, connection, get_context, Array +from multiprocessing import connection, get_context from collections import namedtuple from ditk import logging import platform @@ -8,33 +8,19 @@ import gym import traceback import torch -import ctypes import pickle import cloudpickle import numpy as np import treetensor.numpy as tnp from easydict import EasyDict from types import MethodType +from ding.data import ShmBufferContainer, ShmBuffer from ding.envs.env import BaseEnvTimestep from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY, make_key_as_identifier, \ remove_illegal_item from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper -_NTYPE_TO_CTYPE = { - np.bool_: ctypes.c_bool, - np.uint8: ctypes.c_uint8, - np.uint16: ctypes.c_uint16, - np.uint32: ctypes.c_uint32, - np.uint64: ctypes.c_uint64, - np.int8: ctypes.c_int8, - np.int16: ctypes.c_int16, - np.int32: ctypes.c_int32, - np.int64: ctypes.c_int64, - np.float32: ctypes.c_float, - np.float64: ctypes.c_double, -} - def is_abnormal_timestep(timestep: namedtuple) -> bool: if isinstance(timestep.info, dict): @@ -45,110 +31,6 @@ def is_abnormal_timestep(timestep: namedtuple) -> bool: raise TypeError("invalid env timestep type: {}".format(type(timestep.info))) -class ShmBuffer(): - """ - Overview: - Shared memory buffer to store numpy array. - """ - - def __init__(self, dtype: Union[type, np.dtype], shape: Tuple[int], copy_on_get: bool = True) -> None: - """ - Overview: - Initialize the buffer. - Arguments: - - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. - - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. - - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. - """ - if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype - dtype = dtype.type - self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape))) - self.dtype = dtype - self.shape = shape - self.copy_on_get = copy_on_get - - def fill(self, src_arr: np.ndarray) -> None: - """ - Overview: - Fill the shared memory buffer with a numpy array. (Replace the original one.) - Arguments: - - src_arr (:obj:`np.ndarray`): array to fill the buffer. - """ - assert isinstance(src_arr, np.ndarray), type(src_arr) - # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten - # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten - # so we reshape dst_arr rather than flatten src_arr - dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) - np.copyto(dst_arr, src_arr) - - def get(self) -> np.ndarray: - """ - Overview: - Get the array stored in the buffer. - Return: - - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. 
- """ - data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) - if self.copy_on_get: - data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory - return data - - -class ShmBufferContainer(object): - """ - Overview: - Support multiple shared memory buffers. Each key-value is name-buffer. - """ - - def __init__( - self, - dtype: Union[Dict[Any, type], type, np.dtype], - shape: Union[Dict[Any, tuple], tuple], - copy_on_get: bool = True - ) -> None: - """ - Overview: - Initialize the buffer container. - Arguments: - - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. - - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ - multiple buffers; If `tuple`, use single buffer. - - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. - """ - if isinstance(shape, dict): - self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} - elif isinstance(shape, (tuple, list)): - self._data = ShmBuffer(dtype, shape, copy_on_get) - else: - raise RuntimeError("not support shape: {}".format(shape)) - self._shape = shape - - def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: - """ - Overview: - Fill the one or many shared memory buffer. - Arguments: - - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. - """ - if isinstance(self._shape, dict): - for k in self._shape.keys(): - self._data[k].fill(src_arr[k]) - elif isinstance(self._shape, (tuple, list)): - self._data.fill(src_arr) - - def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: - """ - Overview: - Get the one or many arrays stored in the buffer. - Return: - - data (:obj:`np.ndarray`): The array(s) stored in the buffer. - """ - if isinstance(self._shape, dict): - return {k: self._data[k].get() for k in self._shape.keys()} - elif isinstance(self._shape, (tuple, list)): - return self._data.get() - - class CloudPickleWrapper: """ Overview: diff --git a/ding/example/__init__.py b/ding/example/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ding/example/dqn.py b/ding/example/dqn.py index 88964a59b9..d343ee4ea7 100644 --- a/ding/example/dqn.py +++ b/ding/example/dqn.py @@ -1,5 +1,43 @@ +""" +# Example of DQN pipeline + +Use the pipeline on a single process: + +> python ding/example/dqn.py + +Use the pipeline on multiple processes: + +We surpose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) actors + +## First Example —— Execute on one machine with multi processes. + +Execute 4 processes with 1 learner + 1 evaluator + 2 actors +Remember to keep them connected by mesh to ensure that they can exchange information with each other. + +> ditask --package . --main ding.example.dqn.main --parallel-workers 4 --topology mesh + +## Second Example —— Execute on multiple machines. + +1. Execute 1 learner + 1 evaluator on one machine. + +> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 + +2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). + Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. + Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. + And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. 
+ The value of the `attach_to` parameter should be obtained from the log of the + process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). + +> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology alone --node-ids 2 \ + --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 + +3. You can repeat step 2 to start more collectors on other machines. +""" import gym from ditk import logging +from ding.data.model_loader import FileModelLoader +from ding.data.storage_loader import FileStorageLoader from ding.model import DQN from ding.policy import DQNPolicy from ding.envs import DingEnvWrapper, BaseEnvManagerV2 @@ -8,7 +46,7 @@ from ding.framework import task from ding.framework.context import OnlineRLContext from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger from ding.utils import set_pkg_seed from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config @@ -32,12 +70,29 @@ def main(): buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) policy = DQNPolicy(cfg.policy, model=model) + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) + + # Here is the part of single process pipeline. task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=100)) + task.run() diff --git a/ding/example/dqn_dist.py b/ding/example/dqn_dist.py deleted file mode 100644 index 0fd1e09111..0000000000 --- a/ding/example/dqn_dist.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -The distributed version of DQN pipeline. -With N workers = 1 learner + 1 evaluator + (N-2) actors - -# First Example —— Execute on one machine with multi processes. -Execute 4 processes with 1 learner + 1 evaluator + 2 actors -Remember to keep them connected by mesh to ensure that they can exchange information with each other. - -> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 4 --topology mesh - -# Second Example —— Execute on multiple machines. -1. Execute 1 learner + 1 evaluator on one machine. - -> ditask --package . --main ding.example.dqn_dist.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 - -2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). - Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. - Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. - And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. - The value of the `attach_to` parameter should be obtained from the log of the - process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). - -> ditask --package . 
--main ding.example.dqn_dist.main --parallel-workers 2 --topology alone --node-ids 2 \ - --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 - -3. You can repeat step 2 to start more collectors on other machines. -""" -import gym -from ditk import logging -from ding.model import DQN -from ding.policy import DQNPolicy -from ding.envs import DingEnvWrapper, BaseEnvManagerV2 -from ding.data import DequeBuffer -from ding.config import compile_config -from ding.framework import task -from ding.framework.context import OnlineRLContext -from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ - eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger -from ding.utils import set_pkg_seed -from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config - - -def main(): - logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True) - # cfg.env.stop_value = 99999999 # Don't stop - with task.start(async_mode=False, ctx=OnlineRLContext()): - assert task.router.is_active, "Please execute this script with ditask! See note in the header." - - set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - - model = DQN(**cfg.policy.model) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - policy = DQNPolicy(cfg.policy, model=model) - - if task.router.node_id == 0: # Learner - logging.info("Learner running on node {}".format(task.router.node_id)) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - task.use( - context_exchanger( - send_keys=["train_iter"], - recv_keys=["trajectories", "episodes", "env_step", "env_episode"], - skip_n_iter=0 - ) - ) - task.use(model_exchanger(model, is_learner=True)) - task.use(data_pusher(cfg, buffer_)) - task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) - - elif task.router.node_id == 1: # Evaluator - logging.info("Evaluator running on node {}".format(task.router.node_id)) - evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], - cfg=cfg.env.manager - ) - task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) - task.use(model_exchanger(model, is_learner=False)) - task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) - task.use(CkptSaver(cfg, policy, save_finish=False)) - - else: # Collectors - logging.info("Collector running on node {}".format(task.router.node_id)) - collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], - cfg=cfg.env.manager - ) - task.use( - context_exchanger( - send_keys=["trajectories", "episodes", "env_step", "env_episode"], - recv_keys=["train_iter"], - skip_n_iter=1 - ) - ) - task.use(model_exchanger(model, is_learner=False)) - task.use(eps_greedy_handler(cfg)) - task.use(StepCollector(cfg, policy.collect_mode, collector_env)) - - task.run() - - -if __name__ == "__main__": - main() diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index 274d6f2364..72c23d0475 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -1,5 +1,5 @@ from .context import Context, OnlineRLContext, OfflineRLContext -from .task import Task, task +from .task import Task, task, VoidMiddleware from .parallel import Parallel from .event_loop import EventLoop from 
.supervisor import Supervisor diff --git a/ding/framework/event_loop.py b/ding/framework/event_loop.py index b5f58720a1..6641d07adb 100644 --- a/ding/framework/event_loop.py +++ b/ding/framework/event_loop.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Callable, Optional from concurrent.futures import ThreadPoolExecutor +from copy import copy import fnmatch from ditk import logging @@ -35,7 +36,10 @@ def off(self, event: str, fn: Optional[Callable] = None) -> None: """ for e in fnmatch.filter(self._listeners.keys(), event): if fn: - self._listeners[e].remove(fn) + try: + self._listeners[e].remove(fn) + except: + pass else: self._listeners[e] = [] @@ -79,7 +83,7 @@ def _trigger(self, event: str, *args, **kwargs) -> None: if event not in self._listeners: logging.debug("Event {} is not registered in the callbacks of {}!".format(event, self._name)) return - for fn in self._listeners[event]: + for fn in copy(self._listeners[event]): try: fn(*args, **kwargs) except Exception as e: diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index a2c428932c..558d9affa3 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -2,3 +2,4 @@ from .collector import StepCollector, EpisodeCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver +from .distributer import ContextExchanger, ModelExchanger diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index a287199eb1..b47bcc6ee4 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -17,6 +17,11 @@ class CkptSaver: The class used to save checkpoint data. """ + def __new__(cls, *args, **kwargs): + if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)): + return task.void() + return super(CkptSaver, cls).__new__(cls) + def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None, save_finish: bool = True): """ Overview: diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index fa70a00766..6660025b33 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING from easydict import EasyDict from ding.policy import get_random_policy @@ -17,6 +17,11 @@ class StepCollector: process. Use the `__call__` method to execute the whole collection process. 
""" + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(StepCollector, cls).__new__(cls) + def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None: """ Arguments: diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py new file mode 100644 index 0000000000..815a78bca9 --- /dev/null +++ b/ding/framework/middleware/distributer.py @@ -0,0 +1,271 @@ +from time import sleep, time +from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union +from ditk import logging +from ding.framework import task +from ding.data import StorageLoader, Storage, ModelLoader +if TYPE_CHECKING: + from ding.framework.context import Context + from torch.nn import Module + + +class ContextExchanger: + + def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None: + """ + Overview: + Exchange context between processes, + support properties: trajectories, episodes, env_step, env_episode, train_iter + Arguments: + - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \ + for the first n iterations to collect data for the learner to learn. This parameter \ + will not work on learner. + - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into storage class to reduce \ + the network overhead. + """ + if not task.router.is_active: + raise RuntimeError("ContextHandler should be used in parallel mode!") + self._state = {} + self._event_name = "context_exchanger_{role}" + self._skip_n_iter = skip_n_iter + self._storage_loader = storage_loader + for role in task.role: # Only subscribe to other roles + if not task.has_role(role): + task.on(self._event_name.format(role=role), self.put) + if storage_loader: + task.once("finish", lambda _: storage_loader.shutdown()) + + def __new__(cls, *args, **kwargs): + if not task.router.is_active: + return task.void() + + if len(task.roles) == 0: + logging.warning("The task does not have any roles defined, the ContextExchanger will not work.") + return task.void() + + if len(task.roles) > 1: + logging.warning( + "Use multiple roles in one exchanger may lead to unexpected result, please check your code." + ) + + return super(ContextExchanger, cls).__new__(cls) + + def __call__(self, ctx: "Context"): + self.merge(ctx) + yield + payload = self.fetch(ctx) + if payload: + if self._storage_loader and task.has_role(task.role.COLLECTOR): + payload = self._storage_loader.save(payload) + for role in task.roles: + task.emit(self._event_name.format(role=role), payload, only_remote=True) + + def __del__(self): + if self._storage_loader: + self._storage_loader.shutdown() + + def put(self, payload: Union[Dict, Storage]): + """ + Overview: + Get attributes from ctx on the callback of event. + Each attribute should have a standalone put handler, which named `_put_{key}` + """ + + def callback(payload: Dict): + for key, item in payload.items(): + fn_name = "_put_{}".format(key) + if hasattr(self, fn_name): + getattr(self, fn_name)(item) + else: + logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + + if isinstance(payload, Storage): + assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object." 
+ self._storage_loader.load(payload, callback) + else: + callback(payload) + + def fetch(self, ctx: "Context") -> Dict[str, Any]: + """ + Overview: + Fetch attributes from ctx before emit them to the event bus. + Each attribute should have a standalone fetch handler, which named `_fetch_{key}` + """ + payload = {} + for key, item in ctx.items(): + fn_name = "_fetch_{}".format(key) + if hasattr(self, fn_name): + value = getattr(self, fn_name)(item) + if value is not None: + payload[key] = value + return payload + + def merge(self, ctx: "Context"): + if task.has_role(task.role.LEARNER): + # Learner should always wait for trajs. + # TODO: Automaticlly wait based on properties, not roles. + while len(self._state) == 0: + sleep(0.01) + elif ctx.total_step >= self._skip_n_iter: + start = time() + while len(self._state) == 0: + if time() - start > 60: + logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id)) + break + sleep(0.01) + + for k, v in self._state.items(): + ctx[k] = v + self._state = {} + + # Handle each attibute of context + def _put_trajectories(self, traj: List[Any]): + if not task.has_role(task.role.LEARNER): + return + if "trajectories" not in self._state: + self._state["trajectories"] = [] + self._state["trajectories"].extend(traj) + + def _fetch_trajectories(self, traj: List[Any]): + if task.has_role(task.role.COLLECTOR): + return traj + + def _put_episodes(self, episodes: List[Any]): + if not task.has_role(task.role.LEARNER): + return + if "episodes" not in self._state: + self._state["episodes"] = [] + self._state["episodes"].extend(episodes) + + def _fetch_episodes(self, episodes: List[Any]): + if task.has_role(task.role.COLLECTOR): + return episodes + + def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]): + if not task.has_role(task.role.LEARNER): + return + if "trajectory_end_idx" not in self._state: + self._state["trajectory_end_idx"] = [] + self._state["trajectory_end_idx"].extend(trajectory_end_idx) + + def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): + if task.has_role(task.role.COLLECTOR): + return trajectory_end_idx + + def _put_env_step(self, env_step: int): + if not task.has_role(task.role.COLLECTOR): + self._state["env_step"] = env_step + + def _fetch_env_step(self, env_step: int): + if task.has_role(task.role.COLLECTOR): + return env_step + + def _put_env_episode(self, env_episode: int): + if not task.has_role(task.role.COLLECTOR): + self._state["env_episode"] = env_episode + + def _fetch_env_episode(self, env_episode: int): + if task.has_role(task.role.COLLECTOR): + return env_episode + + def _put_train_iter(self, train_iter: int): + if not task.has_role(task.role.LEARNER): + self._state["train_iter"] = train_iter + + def _fetch_train_iter(self, train_iter: int): + if task.has_role(task.role.LEARNER): + return train_iter + + +class ModelExchanger: + + def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None: + """ + Overview: + Exchange model between processes, only the learner will send the model, + otherwise the model will only be received. + If you are using a shared model on a single host, there is no need to use this middleware. + Arguments: + - model (:obj:`torch.nn.Module`): Pytorch module. + - model_loader (:obj:`ModelLoader`): Encode model in subprocess. 
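+                When provided, the learner emits a lightweight storage reference instead of the \
+                full state dict, and receivers load the real state dict back via ``model_loader.load``.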
+ """ + self._model = model + self._model_loader = model_loader + self._event_name = "model_exchanger" + self._state_dict_cache: Optional[Union[object, Storage]] = None + self._is_learner = task.has_role(task.role.LEARNER) + if not self._is_learner: + task.on(self._event_name, self._cache_state_dict) + if model_loader: + task.once("finish", lambda _: model_loader.shutdown()) + + def _cache_state_dict(self, state_dict: Union[object, Storage]): + self._state_dict_cache = state_dict + + def __new__(cls, *args, **kwargs): + if not task.router.is_active: + return task.void() + + if len(task.roles) == 0: + logging.warning("The task does not have any roles defined, the ModelExchanger will not work.") + return task.void() + + if len(task.roles) > 1: + logging.warning( + "Use multiple roles in one exchanger may lead to unexpected result, please check your code." + ) + + return super(ModelExchanger, cls).__new__(cls) + + def __call__(self, ctx: "Context") -> Any: + if self._model_loader: + self._model_loader.start() + + if not self._is_learner: + if ctx.total_step != 0: # Skip first iteration + self._update_model() + else: + yield + self._send_model() + + def _update_model(self): + start = time() + while True: + if task.finish: + return + if time() - start > 60: + logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id)) + break + if self._state_dict_cache is None: + sleep(0.01) + else: + if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None: + try: + self._model.load_state_dict(self._model_loader.load(self._state_dict_cache)) + self._state_dict_cache = None + break + except FileNotFoundError as e: + logging.warning( + "Model file has been deleted on node {}, maybe you can increase the ttl.".format( + task.router.node_id + ) + ) + self._state_dict_cache = None + continue + else: + self._model.load_state_dict(self._state_dict_cache) + self._state_dict_cache = None + break + + def _send_model(self): + if self._model_loader: + self._model_loader.save(self._send_callback) + else: + task.emit(self._event_name, self._model.state_dict(), only_remote=True) + + def _send_callback(self, storage: Storage): + if task.running: + task.emit(self._event_name, storage, only_remote=True) + + def __del__(self): + if self._model_loader: + self._model_loader.shutdown() diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index d8d3525ea4..42f24ab22b 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -5,10 +5,11 @@ from .evaluator import interaction_evaluator from .termination_checker import termination_checker, ddp_termination_checker from .logger import online_logger, offline_logger -from .exchanger import context_exchanger, model_exchanger from .ctx_helper import final_ctx_saver # algorithm from .explorer import eps_greedy_handler, eps_greedy_masker from .advantage_estimator import gae_estimator from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer + +from .timer import epoch_timer diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index 01f9cb64ea..fa3e7fad39 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -43,7 +43,7 @@ def _gae(ctx: "OnlineRLContext"): data = ctx.trajectories # list data = ttorch_collate(data) with 
torch.no_grad(): - if cfg.policy.cuda: + if cfg.policy.get("cuda", False): data = data.cuda() value = model.forward(data.obs, mode='compute_critic')['value'] next_value = model.forward(data.next_obs, mode='compute_critic')['value'] @@ -56,7 +56,7 @@ def _gae(ctx: "OnlineRLContext"): # done is bool type when acquired from env.step data_ = gae_data(data.value, next_value, data.reward, data.done.float(), traj_flag.float()) data.adv = gae(data_, cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda) - if cfg.policy.cuda: + if cfg.policy.get("cuda", False): data = data.cpu() if buffer_ is None: ctx.train_data = data diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 0ec60de3b0..06679ef9f8 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -4,6 +4,7 @@ import torch from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type from ding.data.buffer.middleware import PriorityExperienceReplay +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext @@ -17,6 +18,8 @@ def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = N - cfg (:obj:`EasyDict`): Config. - buffer (:obj:`Buffer`): Buffer to push the data in. """ + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _push(ctx: "OnlineRLContext"): """ diff --git a/ding/framework/middleware/functional/enhancer.py b/ding/framework/middleware/functional/enhancer.py index 4bd9ad45e4..b983945791 100644 --- a/ding/framework/middleware/functional/enhancer.py +++ b/ding/framework/middleware/functional/enhancer.py @@ -2,7 +2,7 @@ from easydict import EasyDict from ditk import logging import torch -from ding.policy import Policy +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext from ding.reward_model import BaseRewardModel, HerRewardModel @@ -17,6 +17,8 @@ def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable - cfg (:obj:`EasyDict`): Config. - reward_model (:obj:`BaseRewardModel`): Reward model. """ + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _enhance(ctx: "OnlineRLContext"): """ @@ -40,6 +42,8 @@ def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRe - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \ which is used to process episodes. 
""" + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _fetch_and_enhance(ctx: "OnlineRLContext"): """ @@ -69,6 +73,9 @@ def _fetch_and_enhance(ctx: "OnlineRLContext"): def nstep_reward_enhancer(cfg: EasyDict) -> Callable: + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() + def _enhance(ctx: "OnlineRLContext"): nstep = cfg.policy.nstep gamma = cfg.policy.discount_factor diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 35294a449c..8264f443a9 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Any, List, Union +from typing import Callable, Any, List, Union from abc import ABC, abstractmethod from collections import deque from ditk import logging @@ -154,6 +154,8 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> - policy (:obj:`Policy`): The policy to be evaluated. - env (:obj:`BaseEnvManager`): The env for the evaluation. """ + if task.router.is_active and not task.has_role(task.role.EVALUATOR): + return task.void() env.seed(cfg.seed, dynamic_seed=False) diff --git a/ding/framework/middleware/functional/exchanger.py b/ding/framework/middleware/functional/exchanger.py deleted file mode 100644 index 81e0051bf4..0000000000 --- a/ding/framework/middleware/functional/exchanger.py +++ /dev/null @@ -1,77 +0,0 @@ -from ditk import logging -from queue import Empty -from typing import TYPE_CHECKING, List, Dict -from ding.framework import task -from ding.utils.data.structure.lifo_deque import LifoDeque -if TYPE_CHECKING: - from ding.framework.context import Context - from torch.nn import Module - - -def context_exchanger(send_keys: List[str] = None, recv_keys: List[str] = None, skip_n_iter: int = 0): - """ - Overview: - Send data from context in the backward stage. - Buffer received data and wait if not get any data. - Arguments: - - send_keys (:obj:`List[str]`): Keys need to be sent. - - recv_keys (:obj:`List[str]`): Keys need to be received. - - skip_n_iter (:obj:`int`): Whether to skip the first N round of waiting, - e.g. collecting data without waiting for a new model in the first N round, - while training a model that needs to wait for data in the first round. - """ - event_name = "context_exchanger" - - bufferd_payloads = LifoDeque(maxsize=100) - task.on(event_name, lambda payload: bufferd_payloads.put(payload)) - - def _context_exchanger(ctx: "Context"): - if recv_keys: - if ctx.total_step >= skip_n_iter: - payload: Dict = bufferd_payloads.get() - for key in recv_keys: - value = payload.get(key) - if value: - ctx[key] = value - - if send_keys: - yield - payload = {} - for key in send_keys: - payload[key] = ctx.get(key) - if payload: - task.emit(event_name, payload, only_remote=True) - - return _context_exchanger - - -def model_exchanger(model: "Module", is_learner: bool = False): - """ - Overview: - Exchange model between processes, only the learner will send the model, - otherwise the model will only be received. - If you are using a shared model on a single host, there is no need to use this middleware. - Arguments: - - model (:obj:`torch.nn.Module`): Pytorch module. - - is_learner (:obj:`bool`): Whether use this middleware as learner or not. 
- """ - event_name = "model_exchanger" - bufferd_state_dict = LifoDeque(maxsize=1) - - if not is_learner: - task.on(event_name, lambda state_dict: bufferd_state_dict.put(state_dict)) - - def _model_exchanger(ctx: "Context"): - if not is_learner: - if ctx.total_step != 0: # Skip first iteration - try: - state_dict = bufferd_state_dict.get(timeout=5) - model.load_state_dict(state_dict) - except Empty: - logging.warning("Timeout when waiting for new model!") - - if is_learner: - yield - task.emit(event_name, model.state_dict(), only_remote=True) - - return _model_exchanger diff --git a/ding/framework/middleware/functional/explorer.py b/ding/framework/middleware/functional/explorer.py index 4c7364004d..45aa9bd24a 100644 --- a/ding/framework/middleware/functional/explorer.py +++ b/ding/framework/middleware/functional/explorer.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable from easydict import EasyDict from ding.rl_utils import get_epsilon_greedy_fn +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -13,6 +14,8 @@ def eps_greedy_handler(cfg: EasyDict) -> Callable: Arguments: - cfg (:obj:`EasyDict`): Config. """ + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() eps_cfg = cfg.policy.other.eps handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) diff --git a/ding/framework/middleware/functional/timer.py b/ding/framework/middleware/functional/timer.py new file mode 100644 index 0000000000..db8a2c0056 --- /dev/null +++ b/ding/framework/middleware/functional/timer.py @@ -0,0 +1,35 @@ +import numpy as np +from collections import deque +from ditk import logging +from time import time + +from ding.framework import task +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ding.framework.context import Context + + +def epoch_timer(print_per: int = 1, smooth_window: int = 10): + """ + Overview: + Print time cost of each epoch. + Arguments: + - print_per (:obj:`int`): Print each N epoch. + - smooth_window (:obj:`int`): The window size to smooth the mean. + """ + records = deque(maxlen=print_per * smooth_window) + + def _epoch_timer(ctx: "Context"): + start = time() + yield + time_cost = time() - start + records.append(time_cost) + if ctx.total_step % print_per == 0: + logging.info( + "[Epoch Timer][Node:{:>2}]: Cost: {:.2f}ms, Mean: {:.2f}ms".format( + task.router.node_id or 0, time_cost * 1000, + np.mean(records) * 1000 + ) + ) + + return _epoch_timer diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 4d60f117bd..91184a7b9b 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -17,6 +17,11 @@ class OffPolicyLearner: the `__call__` method to execute the whole learning process. 
""" + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() + return super(OffPolicyLearner, cls).__new__(cls) + def __init__( self, cfg: EasyDict, diff --git a/ding/framework/middleware/tests/test_ckpt_handler.py b/ding/framework/middleware/tests/test_ckpt_handler.py index 9a22ffece0..56a3dbf0d4 100644 --- a/ding/framework/middleware/tests/test_ckpt_handler.py +++ b/ding/framework/middleware/tests/test_ckpt_handler.py @@ -11,7 +11,7 @@ from unittest.mock import Mock, patch from ding.framework import task -from ding.utils import save_file +from ding.policy.base_policy import Policy class TheModelClass(nn.Module): @@ -22,10 +22,14 @@ def state_dict(self): class MockPolicy(Mock): - def __init__(self, model) -> None: - super(MockPolicy, self).__init__() + def __init__(self, model, **kwargs) -> None: + super(MockPolicy, self).__init__(model) self.learn_mode = model + @property + def eval_mode(self): + return EasyDict({"state_dict": lambda: {}}) + @pytest.mark.unittest def test_ckpt_saver(): diff --git a/ding/framework/middleware/tests/test_distributer.py b/ding/framework/middleware/tests/test_distributer.py new file mode 100644 index 0000000000..c7c323bac9 --- /dev/null +++ b/ding/framework/middleware/tests/test_distributer.py @@ -0,0 +1,223 @@ +import shutil +from time import sleep +import pytest +import numpy as np +import tempfile + +import torch +from ding.data.model_loader import FileModelLoader +from ding.data.storage_loader import FileStorageLoader +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger +from ding.framework.parallel import Parallel +from ding.utils.default_helper import set_pkg_seed +from os import path + + +def context_exchanger_main(): + with task.start(ctx=OnlineRLContext()): + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.COLLECTOR) + + task.use(ContextExchanger(skip_n_iter=1)) + + if task.has_role(task.role.LEARNER): + + def learner_context(ctx: OnlineRLContext): + assert len(ctx.trajectories) == 2 + assert len(ctx.trajectory_end_idx) == 4 + assert len(ctx.episodes) == 8 + assert ctx.env_step > 0 + assert ctx.env_episode > 0 + yield + ctx.train_iter += 1 + + task.use(learner_context) + elif task.has_role(task.role.COLLECTOR): + + def collector_context(ctx: OnlineRLContext): + if ctx.total_step > 0: + assert ctx.train_iter > 0 + yield + ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)] + ctx.trajectory_end_idx = [1 for _ in range(4)] + ctx.episodes = [np.random.rand(10, 10) for _ in range(8)] + ctx.env_step += 1 + ctx.env_episode += 1 + + task.use(collector_context) + + task.run(max_step=3) + + +@pytest.mark.unittest +def test_context_exchanger(): + Parallel.runner(n_parallel_workers=2)(context_exchanger_main) + + +def context_exchanger_with_storage_loader_main(): + with task.start(ctx=OnlineRLContext()): + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.COLLECTOR) + + tempdir = path.join(tempfile.gettempdir(), "test_storage_loader") + storage_loader = FileStorageLoader(dirname=tempdir) + try: + task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader)) + + if task.has_role(task.role.LEARNER): + + def learner_context(ctx: OnlineRLContext): + assert len(ctx.trajectories) == 2 + assert len(ctx.trajectory_end_idx) 
== 4 + assert len(ctx.episodes) == 8 + assert ctx.env_step > 0 + assert ctx.env_episode > 0 + yield + ctx.train_iter += 1 + + task.use(learner_context) + elif task.has_role(task.role.COLLECTOR): + + def collector_context(ctx: OnlineRLContext): + if ctx.total_step > 0: + assert ctx.train_iter > 0 + yield + ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)] + ctx.trajectory_end_idx = [1 for _ in range(4)] + ctx.episodes = [np.random.rand(10, 10) for _ in range(8)] + ctx.env_step += 1 + ctx.env_episode += 1 + + task.use(collector_context) + + task.run(max_step=3) + finally: + storage_loader.shutdown() + sleep(1) + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.unittest +def test_context_exchanger_with_storage_loader(): + Parallel.runner(n_parallel_workers=2)(context_exchanger_with_storage_loader_main) + + +class MockPolicy: + + def __init__(self) -> None: + self._model = self._get_model(10, 10) + + def _get_model(self, X_shape, y_shape) -> torch.nn.Module: + return torch.nn.Sequential( + torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(), + torch.nn.Linear(24, y_shape) + ) + + def train(self, X, y): + loss_fn = torch.nn.MSELoss(reduction="mean") + optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01) + y_pred = self._model(X) + loss = loss_fn(y_pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def predict(self, X): + with torch.no_grad(): + return self._model(X) + + +def model_exchanger_main(): + with task.start(ctx=OnlineRLContext()): + set_pkg_seed(0, use_cuda=False) + policy = MockPolicy() + X = torch.rand(10) + y = torch.rand(10) + + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + task.use(ModelExchanger(policy._model)) + + if task.has_role(task.role.LEARNER): + + def train(ctx): + policy.train(X, y) + sleep(0.3) + + task.use(train) + else: + y_pred1 = policy.predict(X) + + def pred(ctx): + if ctx.total_step > 0: + y_pred2 = policy.predict(X) + # Ensure model is upgraded + assert any(y_pred1 != y_pred2) + sleep(0.3) + + task.use(pred) + + task.run(2) + + +@pytest.mark.unittest +def test_model_exchanger(): + Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main) + + +def model_exchanger_main_with_model_loader(): + with task.start(ctx=OnlineRLContext()): + set_pkg_seed(0, use_cuda=False) + policy = MockPolicy() + X = torch.rand(10) + y = torch.rand(10) + + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + model_loader = FileModelLoader(policy._model, dirname=tempdir) + task.use(ModelExchanger(policy._model, model_loader=model_loader)) + + try: + if task.has_role(task.role.LEARNER): + + def train(ctx): + policy.train(X, y) + sleep(0.3) + + task.use(train) + else: + y_pred1 = policy.predict(X) + + def pred(ctx): + if ctx.total_step > 0: + y_pred2 = policy.predict(X) + # Ensure model is upgraded + assert any(y_pred1 != y_pred2) + sleep(0.3) + + task.use(pred) + task.run(2) + finally: + model_loader.shutdown() + sleep(0.3) + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.unittest +def test_model_exchanger_with_model_loader(): + Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader) diff --git a/ding/framework/middleware/tests/test_enhancer.py b/ding/framework/middleware/tests/test_enhancer.py index 
a1765c031d..10d34b264f 100644 --- a/ding/framework/middleware/tests/test_enhancer.py +++ b/ding/framework/middleware/tests/test_enhancer.py @@ -2,8 +2,7 @@ import torch from ding.framework import OnlineRLContext from ding.data.buffer import DequeBuffer -from easydict import EasyDict -from typing import Any, List, Dict, Optional +from typing import Any import numpy as np import copy from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index a4c8b8d0b7..38e343e495 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -3,8 +3,8 @@ import random import time import traceback -from mpire.pool import WorkerPool import pickle +from mpire.pool import WorkerPool from ditk import logging import tempfile import socket @@ -27,6 +27,7 @@ def __init__(self) -> None: self._listener = None self.is_active = False self.node_id = None + self.local_id = None self.labels = set() self._event_loop = EventLoop("parallel_{}".format(id(self))) self._retries = 0 # Retries in auto recovery @@ -34,19 +35,24 @@ def __init__(self) -> None: def _run( self, node_id: int, + local_id: int, n_parallel_workers: int, labels: Optional[Set[str]] = None, auto_recover: bool = False, max_retries: int = float("inf"), mq_type: str = "nng", + startup_interval: int = 1, **kwargs ) -> None: self.node_id = node_id + self.local_id = local_id + self.startup_interval = startup_interval self.n_parallel_workers = n_parallel_workers self.labels = labels or set() self.auto_recover = auto_recover self.max_retries = max_retries self._mq = MQ_REGISTRY.get(mq_type)(**kwargs) + time.sleep(self.local_id * self.startup_interval) self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) self._listener.start() @@ -65,7 +71,8 @@ def runner( auto_recover: bool = False, max_retries: int = float("inf"), redis_host: Optional[str] = None, - redis_port: Optional[int] = None + redis_port: Optional[int] = None, + startup_interval: int = 1 ) -> Callable: """ Overview: @@ -87,6 +94,7 @@ def runner( - max_retries (:obj:`int`): Max retries for auto recover. - redis_host (:obj:`str`): Redis server host. - redis_port (:obj:`int`): Redis server port. + - startup_interval (:obj:`int`): Start up interval between each task. Returns: - _runner (:obj:`Callable`): The wrapper function for main. """ @@ -104,7 +112,10 @@ def _runner(main_process: Callable, *args, **kwargs) -> None: - main_process (:obj:`Callable`): The main function, your program start from here. """ runner_params = args_parsers[mq_type](**all_args) - params_group = [[runner_kwargs, (main_process, args, kwargs)] for runner_kwargs in runner_params] + params_group = [] + for i, runner_kwargs in enumerate(runner_params): + runner_kwargs["local_id"] = i + params_group.append([runner_kwargs, (main_process, args, kwargs)]) if n_parallel_workers == 1: cls._subprocess_runner(*params_group[0]) @@ -181,6 +192,7 @@ def _subprocess_runner(cls, runner_kwargs: dict, main_params: Tuple[Union[List, - runner_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for runner. - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for main function. 
""" + logging.getLogger().setLevel(logging.INFO) main_process, args, kwargs = main_params with Parallel() as router: @@ -322,7 +334,7 @@ def emit(self, event: str, *args, **kwargs) -> None: if self.is_active: payload = {"a": args, "k": kwargs} try: - data = pickle.dumps(payload, protocol=-1) + data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) except AttributeError as e: logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) raise e @@ -353,12 +365,12 @@ def get_ip(cls): try: # doesn't even have to be reachable s.connect(('10.255.255.255', 1)) - IP = s.getsockname()[0] + ip = s.getsockname()[0] except Exception: - IP = '127.0.0.1' + ip = '127.0.0.1' finally: s.close() - return IP + return ip def __enter__(self) -> "Parallel": return self diff --git a/ding/framework/supervisor.py b/ding/framework/supervisor.py index 22f67177f0..7d385c12c6 100644 --- a/ding/framework/supervisor.py +++ b/ding/framework/supervisor.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -import multiprocessing as mp +import functools +import torch.multiprocessing as mp +from multiprocessing.context import BaseContext import threading import queue import platform @@ -12,6 +14,13 @@ from enum import Enum +@functools.lru_cache(maxsize=1) +def get_mp_ctx() -> BaseContext: + context = 'spawn' if platform.system().lower() == 'windows' else 'fork' + mp_ctx = mp.get_context(context) + return mp_ctx + + @dataclass class SendPayload: proc_id: int @@ -29,6 +38,7 @@ class RecvPayload: method: str = None data: Any = None err: Exception = None + extra: Any = None class ReserveMethod(Enum): @@ -41,27 +51,16 @@ class ChildType(Enum): THREAD = "thread" -@dataclass -class SharedObject: - buf: Any - callback: Callable - - class Child(ABC): """ Abstract class of child process/thread. 
""" - def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs - ) -> None: + def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None: self._proc_id = proc_id self._init = init - self._args = args - self._kwargs = kwargs self._recv_queue = None self._send_queue = None - self._shared_object = shared_object @abstractmethod def start(self, recv_queue: Union[mp.Queue, queue.Queue]): @@ -82,15 +81,17 @@ def send(self, payload: SendPayload): def _target( self, proc_id: int, - init: Callable, - args: List, - kwargs: Dict[str, Any], + init: Union[Callable, object], send_queue: Union[mp.Queue, queue.Queue], recv_queue: Union[mp.Queue, queue.Queue], - shared_object: Optional[SharedObject] = None + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None ): send_payload = SendPayload(proc_id=proc_id) - child_ins = init(*args, **kwargs) + if isinstance(init, Callable): + child_ins = init() + else: + child_ins = init while True: try: send_payload: SendPayload = send_queue.get() @@ -103,8 +104,8 @@ def _target( recv_payload = RecvPayload( proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data ) - if shared_object: - shared_object.callback(recv_payload, shared_object.buf) + if shm_callback is not None and shm_buffer is not None: + shm_callback(recv_payload, shm_buffer) recv_queue.put(recv_payload) except Exception as e: logging.warning(traceback.format_exc()) @@ -121,27 +122,35 @@ def __del__(self): class ChildProcess(Child): def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs + self, + proc_id: int, + init: Union[Callable, object], + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None, + mp_ctx: Optional[BaseContext] = None, + **kwargs ) -> None: - super().__init__(proc_id, init, *args, shared_object=shared_object, **kwargs) + super().__init__(proc_id, init, **kwargs) self._proc = None + self._mp_ctx = mp_ctx + self._shm_buffer = shm_buffer + self._shm_callback = shm_callback def start(self, recv_queue: mp.Queue): - self._recv_queue = recv_queue - context = 'spawn' if platform.system().lower() == 'windows' else 'fork' - ctx = mp.get_context(context) - self._send_queue = ctx.Queue() - proc = ctx.Process( - target=self._target, - args=( - self._proc_id, self._init, self._args, self._kwargs, self._send_queue, self._recv_queue, - self._shared_object - ), - name="supervisor_child_{}_{}".format(self._proc_id, time.time()), - daemon=True - ) - proc.start() - self._proc = proc + if self._proc is None: + self._recv_queue = recv_queue + ctx = self._mp_ctx or get_mp_ctx() + self._send_queue = ctx.Queue() + proc = ctx.Process( + target=self._target, + args=( + self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback + ), + name="supervisor_child_{}_{}".format(self._proc_id, time.time()), + daemon=True + ) + proc.start() + self._proc = proc def shutdown(self, timeout: Optional[float] = None): if self._proc: @@ -156,28 +165,30 @@ def shutdown(self, timeout: Optional[float] = None): self._send_queue = None def send(self, payload: SendPayload): + if self._send_queue is None: + logging.warning("Child worker has been terminated or not started.") + return self._send_queue.put(payload) class ChildThread(Child): - def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs - ) -> None: - super().__init__(proc_id, 
init, *args, shared_object=shared_object, **kwargs) + def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None: + super().__init__(proc_id, init, *args, **kwargs) self._thread = None def start(self, recv_queue: queue.Queue): - self._recv_queue = recv_queue - self._send_queue = queue.Queue() - thread = threading.Thread( - target=self._target, - args=(self._proc_id, self._init, self._args, self._kwargs, self._send_queue, self._recv_queue), - name="supervisor_child_{}_{}".format(self._proc_id, time.time()), - daemon=True - ) - thread.start() - self._thread = thread + if self._thread is None: + self._recv_queue = recv_queue + self._send_queue = queue.Queue() + thread = threading.Thread( + target=self._target, + args=(self._proc_id, self._init, self._send_queue, self._recv_queue), + name="supervisor_child_{}_{}".format(self._proc_id, time.time()), + daemon=True + ) + thread.start() + self._thread = thread def shutdown(self, timeout: Optional[float] = None): if self._thread: @@ -187,6 +198,9 @@ def shutdown(self, timeout: Optional[float] = None): self._send_queue = None def send(self, payload: SendPayload): + if self._send_queue is None: + logging.warning("Child worker has been terminated or not started.") + return self._send_queue.put(payload) @@ -194,26 +208,32 @@ class Supervisor: TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread} - QUEUE_MAPPING = { - ChildType.PROCESS: mp.get_context('spawn' if platform.system().lower() == 'windows' else 'fork').Queue, - ChildType.THREAD: queue.Queue - } - - def __init__(self, type_: ChildType) -> None: + def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None: self._children: List[Child] = [] self._type = type_ self._child_class = self.TYPE_MAPPING[self._type] self._running = False self.__queue = None + self._mp_ctx = mp_ctx or get_mp_ctx() - def register(self, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs) -> None: + def register( + self, + init: Union[Callable, object], + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None + ) -> None: proc_id = len(self._children) - self._children.append(self._child_class(proc_id, init, *args, shared_object=shared_object, **kwargs)) + self._children.append( + self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx) + ) @property def _recv_queue(self) -> Union[queue.Queue, mp.Queue]: if not self.__queue: - self.__queue = self.QUEUE_MAPPING[self._type]() + if self._type is ChildType.PROCESS: + self.__queue = self._mp_ctx.Queue() + elif self._type is ChildType.THREAD: + self.__queue = queue.Queue() return self.__queue @_recv_queue.setter @@ -233,6 +253,9 @@ def send(self, payload: SendPayload) -> None: Arguments: - payload (:obj:`SendPayload`): Send payload. 
""" + if not self._running: + logging.warning("Please call start_link before sending any payload to child process.") + return self._children[payload.proc_id].send(payload) def recv(self, ignore_err: bool = False, timeout: float = None) -> RecvPayload: diff --git a/ding/framework/task.py b/ding/framework/task.py index d67c0558b4..ed3e14eb93 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -7,8 +7,11 @@ import concurrent.futures import fnmatch import math +import enum from types import GeneratorType from typing import Any, Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set, Union +import inspect + from ding.framework.context import Context from ding.framework.parallel import Parallel from ding.framework.event_loop import EventLoop @@ -50,11 +53,27 @@ def runtime_handler(task: "Task", *args, async_mode: Optional[bool] = None, **kw return runtime_handler +class Role(str, enum.Enum): + LEARNER = "learner" + COLLECTOR = "collector" + EVALUATOR = "evaluator" + + +class VoidMiddleware: + + def __call__(self, _): + return + + class Task: """ Tash will manage the execution order of the entire pipeline, register new middleware, and generate new context objects. """ + role = Role + + def __init__(self) -> None: + self.router = Parallel() def start( self, @@ -71,6 +90,7 @@ def start( self._wrappers = [] self.ctx = ctx or Context() self._backward_stack = OrderedDict() + self._roles = set() # Bind event loop functions self._event_loop = EventLoop("task_{}".format(id(self))) @@ -85,7 +105,6 @@ def start( self.labels = labels or set() # Parallel segment - self.router = Parallel() if async_mode or self.router.is_active: self._activate_async() @@ -99,6 +118,21 @@ def sync_finish(value): self.init_labels() return self + def add_role(self, role: Role): + self._roles.add(role) + + def has_role(self, role: Role) -> bool: + if len(self._roles) == 0: + return True + return role in self._roles + + @property + def roles(self) -> Set[Role]: + return self._roles + + def void(self): + return VoidMiddleware() + def init_labels(self): if self.async_mode: self.labels.add("async") @@ -120,6 +154,9 @@ def use(self, fn: Callable, lock: Union[bool, Lock] = False) -> 'Task': Returns: - task (:obj:`Task`): The task. 
""" + assert isinstance(fn, Callable), "Middleware function should be a callable object, current fn {}".format(fn) + if isinstance(fn, VoidMiddleware): # Skip void function + return self for wrapper in self._wrappers: fn = wrapper(fn) self._middleware.append(self.wrap(fn, lock=lock)) @@ -192,7 +229,6 @@ def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable: if lock is True: lock = self._thread_lock - @wraps(fn) def forward(ctx: Context): if lock: with lock: @@ -212,6 +248,11 @@ def backward(): return backward + if hasattr(fn, "__name__"): + forward = wraps(fn)(forward) + else: + forward = wraps(fn.__class__)(forward) + return forward @enable_async @@ -258,6 +299,10 @@ def backward(self, backward_stack: Optional[Dict[str, Generator]] = None) -> Non except StopIteration: continue + @property + def running(self): + return self._running + def serial(self, *fns: List[Callable]) -> Callable: """ Overview: diff --git a/ding/framework/tests/test_event_loop.py b/ding/framework/tests/test_event_loop.py index 1dddae6164..2f3545f3f5 100644 --- a/ding/framework/tests/test_event_loop.py +++ b/ding/framework/tests/test_event_loop.py @@ -31,6 +31,8 @@ def callback(n, lock): assert counter == 10 # Test once + counter = 0 + loop.once("count", callback) loop.once("count", callback) loop.emit("count", 10, lock) sleep(0.1) diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py index 3c7c190f0c..429072a3fc 100644 --- a/ding/framework/tests/test_parallel.py +++ b/ding/framework/tests/test_parallel.py @@ -26,8 +26,8 @@ def test_callback(key): @pytest.mark.unittest def test_parallel_run(): - Parallel.runner(n_parallel_workers=2)(parallel_main) - Parallel.runner(n_parallel_workers=2, protocol="tcp")(parallel_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main) + Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main) def uncaught_exception_main(): @@ -43,7 +43,7 @@ def uncaught_exception_main(): def test_uncaught_exception(): # Make one process crash, then the parent process will also crash and output the stack of the wrong process. with pytest.raises(Exception) as exc_info: - Parallel.runner(n_parallel_workers=2, topology="mesh")(uncaught_exception_main) + Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(uncaught_exception_main) e = exc_info._excinfo[1] assert "uncaught exception" in str(e) @@ -52,6 +52,7 @@ def disconnected_main(): router = Parallel() if router.node_id == 0: + time.sleep(0.1) # Receive two messages then exit greets = [] router.on("greeting", lambda: greets.append(".")) @@ -73,7 +74,7 @@ def disconnected_main(): def test_disconnected(): # Make one process exit normally and the rest will still run, even if the network request # is not received by other processes. 
- Parallel.runner(n_parallel_workers=2, topology="mesh")(disconnected_main) + Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(disconnected_main) class AutoRecover: @@ -143,9 +144,13 @@ def main(cls): @pytest.mark.unittest def test_auto_recover(): # With max_retries=1 - Parallel.runner(n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1)(AutoRecover.main) + Parallel.runner( + n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1, startup_interval=0.1 + )(AutoRecover.main) # With max_retries=0 with pytest.raises(Exception) as exc_info: - Parallel.runner(n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0)(AutoRecover.main) + Parallel.runner( + n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0, startup_interval=0.1 + )(AutoRecover.main) e = exc_info._excinfo[1] assert "P1 Error" in str(e) diff --git a/ding/framework/tests/test_supervisor.py b/ding/framework/tests/test_supervisor.py index 57a0f0d49f..b4fdb95dc0 100644 --- a/ding/framework/tests/test_supervisor.py +++ b/ding/framework/tests/test_supervisor.py @@ -1,9 +1,9 @@ import multiprocessing as mp import ctypes -from time import sleep +from time import sleep, time from typing import Any, Dict, List import pytest -from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType, SharedObject +from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType class MockEnv(): @@ -25,13 +25,16 @@ def block(self): def block_reset(self): sleep(10) + def sleep1(self): + sleep(1) + @pytest.mark.unittest @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD]) def test_supervisor(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() for env_id in range(len(sv._children)): @@ -71,6 +74,25 @@ def test_supervisor(type_): sv.shutdown() +@pytest.mark.unittest +def test_supervisor_spawn(): + sv = Supervisor(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn")) + for _ in range(3): + sv.register(MockEnv("AnyArgs")) + sv.start_link() + + for env_id in range(len(sv._children)): + sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"])) + + recv_states: List[RecvPayload] = [] + for _ in range(3): + recv_states.append(sv.recv()) + + assert sum([payload.proc_id for payload in recv_states]) == 3 + assert all([payload.data == 1 for payload in recv_states]) + sv.shutdown() + + class MockCrashEnv(MockEnv): def step(self, _): @@ -86,8 +108,8 @@ def step(self, _): def test_crash_supervisor(type_): sv = Supervisor(type_=type_) for _ in range(2): - sv.register(MockEnv, "AnyArgs") - sv.register(MockCrashEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) + sv.register(lambda: MockCrashEnv("AnyArgs")) sv.start_link() # Send 6 messages, will cause the third subprocess crash @@ -126,7 +148,7 @@ def test_crash_supervisor(type_): def test_recv_all(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() # Test recv_all @@ -162,7 +184,7 @@ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayl def test_timeout(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() send_payloads = [] @@ -202,7 +224,7 @@ def test_timeout(type_): def test_timeout_with_callback(type_): sv = 
Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() send_payloads = [] @@ -239,25 +261,50 @@ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayl sv.shutdown(timeout=1) -@pytest.mark.unittest +@pytest.mark.tmp # gitlab ci and local test pass, github always fail def test_shared_memory(): sv = Supervisor(type_=ChildType.PROCESS) def shm_callback(payload: RecvPayload, shm: Any): - shm[payload.proc_id] = payload.data + shm[payload.proc_id] = payload.req_id payload.data = 0 shm = mp.Array(ctypes.c_uint8, 3) for i in range(3): - sv.register(MockEnv, "AnyArgs", shared_object=SharedObject(buf=shm, callback=shm_callback)) + sv.register(lambda: MockEnv("AnyArgs"), shm_buffer=shm, shm_callback=shm_callback) sv.start_link() + # Send init request for env_id in range(len(sv._children)): - sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"])) + sv.send(SendPayload(proc_id=env_id, req_id=env_id, method="sleep1", args=[])) - for i in range(3): + start = time() + for i in range(6): payload = sv.recv() assert payload.data == 0 - assert shm[payload.proc_id] == 1 + assert shm[payload.proc_id] == payload.req_id + sv.send(SendPayload(proc_id=payload.proc_id, req_id=i, method="sleep1", args=[])) + + # Non blocking + assert time() - start < 3 sv.shutdown() + + +@pytest.mark.benchmark +@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD]) +def test_supervisor_benchmark(type_): + sv = Supervisor(type_=type_) + for _ in range(3): + sv.register(lambda: MockEnv("AnyArgs")) + sv.start_link() + + for env_id in range(len(sv._children)): + sv.send(SendPayload(proc_id=env_id, method="step", args=[""])) + + start = time() + for _ in range(1000): + payload = sv.recv() + sv.send(SendPayload(proc_id=payload.proc_id, method="step", args=[""])) + + assert time() - start < 1 diff --git a/ding/framework/tests/test_task.py b/ding/framework/tests/test_task.py index 18e3325396..8b6f9ee1de 100644 --- a/ding/framework/tests/test_task.py +++ b/ding/framework/tests/test_task.py @@ -1,6 +1,6 @@ import multiprocessing as mp import pytest -from threading import Lock, Thread +from threading import Lock from time import sleep, time import random import dataclasses @@ -126,7 +126,7 @@ def _counter(ctx): @pytest.mark.unittest def test_parallel_pipeline(): - Parallel.runner(n_parallel_workers=2)(parallel_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main) @pytest.mark.unittest @@ -163,7 +163,7 @@ def emit_remote_main(): @pytest.mark.unittest def test_emit_remote(): - Parallel.runner(n_parallel_workers=2)(emit_remote_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(emit_remote_main) @pytest.mark.unittest @@ -229,7 +229,7 @@ def early_stop_main(): @pytest.mark.unittest def test_early_stop(): - Parallel.runner(n_parallel_workers=2)(early_stop_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(early_stop_main) @pytest.mark.unittest @@ -350,7 +350,7 @@ def tick(ctx: Context): def broadcast_main_target(): Parallel.runner( - n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555 + n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555, startup_interval=0.1 )(broadcast_finish_main) @@ -363,11 +363,12 @@ def broadcast_secondary_target(): topology="alone", ports=50556, attach_to=["tcp://127.0.0.1:50555"], - node_ids=[1, 2] + node_ids=[1, 2], + startup_interval=0.1 
)(broadcast_finish_main) -@pytest.mark.unittest +@pytest.mark.tmp # gitlab ci and local test pass, github always fail @pytest.mark.timeout(10) def test_broadcast_finish(): start = time() diff --git a/ding/framework/wrapper/step_timer.py b/ding/framework/wrapper/step_timer.py index f7d123bc62..dfabdd1476 100644 --- a/ding/framework/wrapper/step_timer.py +++ b/ding/framework/wrapper/step_timer.py @@ -5,11 +5,20 @@ import numpy as np import time from ditk import logging +from ding.framework import task class StepTimer: def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None: + """ + Overview: + Print time cost of each step (execute one middleware). + Arguments: + - print_per_step (:obj:`int`): Print each N step. + - smooth_window (:obj:`int`): The window size to smooth the mean. + """ + self.print_per_step = print_per_step self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window)) @@ -36,11 +45,12 @@ def executor(ctx): time_cost += time.time() - start_time else: time_cost = time.time() - start_time - self.records[step_name].append(time_cost * 1000) + self.records[step_name].append(time_cost) if ctx.total_step % self.print_per_step == 0: logging.info( - "[Step Timer] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format( - step_name, time_cost * 1000, np.mean(self.records[step_name]) + "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format( + task.router.node_id or 0, step_name, time_cost * 1000, + np.mean(self.records[step_name]) * 1000 ) ) From 018b1974108d012ee914b8b489c9a3d392728f7c Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 19 Oct 2022 17:19:48 +0800 Subject: [PATCH 62/70] style(nyz): correct yapf style --- ding/framework/context.py | 2 +- ding/framework/middleware/functional/evaluator.py | 4 +--- ding/framework/middleware/functional/logger.py | 3 +-- ding/framework/middleware/functional/trainer.py | 4 +--- ding/framework/middleware/tests/test_logger.py | 3 +-- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/ding/framework/context.py b/ding/framework/context.py index d214f60987..1dbe998b49 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -92,4 +92,4 @@ def __post_init__(self): # This method is called just after __init__ method. Here, concretely speaking, # this method is called just after the object initialize its fields. # We use this method here to keep the fields needed for each iteration. 
- self.keep('train_iter', 'last_eval_iter') \ No newline at end of file + self.keep('train_iter', 'last_eval_iter') diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index ee9e9841a3..4e6d27bf1c 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -7,15 +7,13 @@ import treetensor.torch as ttorch from easydict import EasyDict from ding.envs import BaseEnvManager -from ding.framework.context import OfflineRLContext +from ding.framework.context import Context, OfflineRLContext, OnlineRLContext from ding.policy import Policy from ding.data import Dataset, DataLoader from ding.framework import task from ding.torch_utils import tensor_to_list, to_ndarray, get_shape0 from ding.utils import lists_to_dicts -from ding.framework import Context, OnlineRLContext, OfflineRLContext - class IMetric(ABC): diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index dd887dc25e..b02c4b9013 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -11,7 +11,6 @@ from ding.torch_utils import to_ndarray from ding.utils.default_helper import one_time_warning - if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext @@ -183,4 +182,4 @@ def _plot(ctx: "OnlineRLContext"): } ) - return _plot \ No newline at end of file + return _plot diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 86dabfdb22..7cd1855553 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -3,9 +3,7 @@ from ditk import logging import numpy as np from ding.policy import Policy -from ding.framework import task, OfflineRLContext - -from ding.framework import OnlineRLContext, OfflineRLContext +from ding.framework import task, OfflineRLContext, OnlineRLContext def trainer(cfg: EasyDict, policy: Policy) -> Callable: diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index bb67af55cb..df204788ec 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -13,7 +13,6 @@ from ding.framework import OnlineRLContext, OfflineRLContext from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger - test_folder = "test_exp" test_path = path.join(os.getcwd(), test_folder) cfg = EasyDict({"exp_name": "test_exp"}) @@ -204,4 +203,4 @@ def test_wandb_online_logger_gradient(): wandb_online_logger(cfg, env, model)(ctx) test_wandb_online_logger_metric() - test_wandb_online_logger_gradient() \ No newline at end of file + test_wandb_online_logger_gradient() From 7d931f9c5f15800a5e61ba3b86562353724f438a Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 19 Oct 2022 18:58:59 +0800 Subject: [PATCH 63/70] fix(nyz): fix ctx and logger compatibility bugs --- ding/config/config.py | 2 +- ding/envs/env_manager/base_env_manager.py | 20 +++++++++---------- ding/framework/middleware/distributer.py | 6 ++++-- .../framework/middleware/functional/logger.py | 12 ++++++++++- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/ding/config/config.py b/ding/config/config.py index e083e09e40..5587e11866 100644 --- a/ding/config/config.py +++ b/ding/config/config.py @@ -315,7 +315,7 @@ def save_project_state(exp_name: str) -> None: def 
_fn(cmd: str): return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8") - if subprocess.run("git status", shell=True, stderr=subprocess.PIPE).returncode == 0: + if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0: short_sha = _fn("git describe --always") log = _fn("git log --stat -n 5") diff = _fn("git diff") diff --git a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py index ff23758c52..a5ddd5e85a 100644 --- a/ding/envs/env_manager/base_env_manager.py +++ b/ding/envs/env_manager/base_env_manager.py @@ -122,16 +122,6 @@ def __init__( self._action_space = self._env_ref.action_space self._reward_space = self._env_ref.reward_space self._env_ref.close() - try: - global space_log_flag - if space_log_flag: - logging.info("Env Space Information:") - logging.info("\tObservation Space: {}".format(self._observation_space)) - logging.info("\tAction Space: {}".format(self._action_space)) - logging.info("\tReward Space: {}".format(self._reward_space)) - space_log_flag = False - except: - pass self._env_states = {i: EnvState.VOID for i in range(self._env_num)} self._env_seed = {i: None for i in range(self._env_num)} self._episode_num = self._cfg.episode_num @@ -238,6 +228,16 @@ def launch(self, reset_param: Optional[Dict] = None) -> None: value is the cooresponding reset parameters. """ assert self._closed, "Please first close the env manager" + try: + global space_log_flag + if space_log_flag: + logging.info("Env Space Information:") + logging.info("\tObservation Space: {}".format(self._observation_space)) + logging.info("\tAction Space: {}".format(self._action_space)) + logging.info("\tReward Space: {}".format(self._reward_space)) + space_log_flag = False + except: + pass if reset_param is not None: assert len(reset_param) == len(self._env_fn) self._create_state() diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index 815a78bca9..dc68d437f8 100644 --- a/ding/framework/middleware/distributer.py +++ b/ding/framework/middleware/distributer.py @@ -1,4 +1,5 @@ from time import sleep, time +from dataclasses import fields from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union from ditk import logging from ding.framework import task @@ -91,7 +92,8 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]: Each attribute should have a standalone fetch handler, which named `_fetch_{key}` """ payload = {} - for key, item in ctx.items(): + for field in fields(ctx): + key, item = field.name, getattr(ctx, field.name) fn_name = "_fetch_{}".format(key) if hasattr(self, fn_name): value = getattr(self, fn_name)(item) @@ -114,7 +116,7 @@ def merge(self, ctx: "Context"): sleep(0.01) for k, v in self._state.items(): - ctx[k] = v + setattr(ctx, k, v) self._state = {} # Handle each attibute of context diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index b02c4b9013..466f0fb17c 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -6,6 +6,7 @@ from torch.nn import functional as F import numpy as np import wandb +from ding.framework import task from ding.envs import BaseEnvManagerV2 from ding.utils import DistributedWriter from ding.torch_utils import to_ndarray @@ -16,10 +17,14 @@ def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: + if task.router.is_active and not 
task.has_role(task.role.LEARNER): + return task.void() writer = DistributedWriter.get_instance() last_train_show_iter = -1 def _logger(ctx: "OnlineRLContext"): + if task.finish: + writer.close() nonlocal last_train_show_iter if not np.isinf(ctx.eval_value): @@ -56,9 +61,13 @@ def _logger(ctx: "OnlineRLContext"): def offline_logger() -> Callable: + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() writer = DistributedWriter.get_instance() def _logger(ctx: "OfflineRLContext"): + if task.finish: + writer.close() if not np.isinf(ctx.eval_value): writer.add_scalar('basic/eval_episode_reward_mean-train_iter', ctx.eval_value, ctx.train_iter) if ctx.train_output is not None: @@ -91,7 +100,8 @@ def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model) -> Callable - env (:obj:`BaseEnvManagerV2`): Evaluator environment. - model (:obj:`nn.Module`): Model. ''' - + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] metric_list = ["q_value", "target q_value", "loss", "lr", "entropy"] # Initialize wandb with default settings From f096cfd345db2fa761166c75499a0a8e5325cf41 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 19 Oct 2022 18:59:31 +0800 Subject: [PATCH 64/70] polish(nyz): update demo from cartpole v0 to v1 --- ding/example/c51_nstep.py | 4 ++-- ding/example/dqn.py | 14 +++++++------- ding/example/dqn_new_env.py | 4 ++-- ding/example/dqn_nstep.py | 4 ++-- ding/example/dqn_per.py | 4 ++-- ding/example/dqn_rnd.py | 4 ++-- ding/example/iqn_nstep.py | 4 ++-- ding/example/ppg_offpolicy.py | 4 ++-- ding/example/ppo.py | 4 ++-- ding/example/ppo_offpolicy.py | 4 ++-- ding/example/qrdqn_nstep.py | 4 ++-- ding/example/r2d2.py | 4 ++-- ding/example/sqil.py | 6 +++--- ding/example/sql.py | 4 ++-- ding/example/trex.py | 4 ++-- 15 files changed, 36 insertions(+), 36 deletions(-) diff --git a/ding/example/c51_nstep.py b/ding/example/c51_nstep.py index 0c975957e1..b88fe0d244 100644 --- a/ding/example/c51_nstep.py +++ b/ding/example/c51_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn.py b/ding/example/dqn.py index d7b80aa79a..8e36834bbe 100644 --- a/ding/example/dqn.py +++ b/ding/example/dqn.py @@ -3,15 +3,15 @@ Use the pipeline on a single process: -> python ding/example/dqn.py +> python3 -u ding/example/dqn.py Use the pipeline on multiple processes: -We surpose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) actors +We surpose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) collectors ## First Example —— Execute on one machine with multi processes. 
-Execute 4 processes with 1 learner + 1 evaluator + 2 actors +Execute 4 processes with 1 learner + 1 evaluator + 2 collectors Remember to keep them connected by mesh to ensure that they can exchange information with each other. > ditask --package . --main ding.example.dqn.main --parallel-workers 4 --topology mesh @@ -53,15 +53,15 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True) + cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -99,4 +99,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/ding/example/dqn_new_env.py b/ding/example/dqn_new_env.py index 97b6085aaf..8579dfa242 100644 --- a/ding/example/dqn_new_env.py +++ b/ding/example/dqn_new_env.py @@ -20,12 +20,12 @@ def main(): with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = EnvSupervisor( type_=ChildType.THREAD, - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], **cfg.env.manager ) evaluator_env = EnvSupervisor( type_=ChildType.THREAD, - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], **cfg.env.manager ) diff --git a/ding/example/dqn_nstep.py b/ding/example/dqn_nstep.py index 896f2d03e5..ef5d85d50d 100644 --- a/ding/example/dqn_nstep.py +++ b/ding/example/dqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn_per.py b/ding/example/dqn_per.py index 31196b1552..caffa8a484 100644 --- a/ding/example/dqn_per.py +++ b/ding/example/dqn_per.py @@ -22,11 +22,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - 
env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn_rnd.py b/ding/example/dqn_rnd.py index bbf605ed83..af6db6b7cc 100644 --- a/ding/example/dqn_rnd.py +++ b/ding/example/dqn_rnd.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/iqn_nstep.py b/ding/example/iqn_nstep.py index 49e5a99f84..81a9d0446d 100644 --- a/ding/example/iqn_nstep.py +++ b/ding/example/iqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/ppg_offpolicy.py b/ding/example/ppg_offpolicy.py index ab07efaa73..037f98348a 100644 --- a/ding/example/ppg_offpolicy.py +++ b/ding/example/ppg_offpolicy.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/ppo.py b/ding/example/ppo.py index b2cd7aed01..a63a57fce0 100644 --- a/ding/example/ppo.py +++ b/ding/example/ppo.py @@ -19,11 +19,11 @@ def main(): ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git 
a/ding/example/ppo_offpolicy.py b/ding/example/ppo_offpolicy.py index 5f3b11c579..7610928ea5 100644 --- a/ding/example/ppo_offpolicy.py +++ b/ding/example/ppo_offpolicy.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/qrdqn_nstep.py b/ding/example/qrdqn_nstep.py index f6b6b5a95a..db400c0fac 100644 --- a/ding/example/qrdqn_nstep.py +++ b/ding/example/qrdqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/r2d2.py b/ding/example/r2d2.py index f537ba990e..b7e4c9d52a 100644 --- a/ding/example/r2d2.py +++ b/ding/example/r2d2.py @@ -18,11 +18,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/sqil.py b/ding/example/sqil.py index 2443d9d6c5..3ac7dc0534 100644 --- a/ding/example/sqil.py +++ b/ding/example/sqil.py @@ -24,15 +24,15 @@ def main(): expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) expert_collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: 
DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/sql.py b/ding/example/sql.py index a1c3a75680..999eef4733 100644 --- a/ding/example/sql.py +++ b/ding/example/sql.py @@ -18,11 +18,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/trex.py b/ding/example/trex.py index 7c43f7fc79..9d2ffd9768 100644 --- a/ding/example/trex.py +++ b/ding/example/trex.py @@ -28,12 +28,12 @@ def main(): with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) From 8d34671b6c52aae860ad2eaa42dd86a760780472 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 19 Oct 2022 20:44:55 +0800 Subject: [PATCH 65/70] fix(nyz): fix evaluator condition bug --- ding/framework/middleware/functional/evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 4e6d27bf1c..a09abd317f 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -170,7 +170,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): """ # evaluation will be executed if the task begins or enough train_iter after last evaluation - if ctx.last_eval_iter is not None and \ + if ctx.last_eval_iter != -1 and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return @@ -222,7 +222,7 @@ def metric_evaluator(cfg: EasyDict, policy: Policy, dataset: Dataset, metric: IM def _evaluate(ctx: "Context"): # evaluation will be executed if the task begins or enough train_iter after last evaluation - if ctx.last_eval_iter is not None and \ + if ctx.last_eval_iter != -1 and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return From 590728f6143bd3c55b93fbaafddcf7d21cae85cb Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 14 Nov 2022 12:37:16 +0800 Subject: [PATCH 66/70] style(nyz): correct flake8 style --- ding/framework/middleware/functional/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index b8c8ac671d..98197b1ef3 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -12,4 +12,4 @@ from 
.advantage_estimator import gae_estimator from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer -from .timer import epoch_timer \ No newline at end of file +from .timer import epoch_timer From 2236e8fd8a499b5a9574dfea6194e5b1ba24fc17 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 17 Nov 2022 18:14:39 +0800 Subject: [PATCH 67/70] demo(nyz): move back to CartPole-v0 --- ding/example/c51_nstep.py | 4 ++-- ding/example/dqn.py | 4 ++-- ding/example/dqn_new_env.py | 4 ++-- ding/example/dqn_nstep.py | 4 ++-- ding/example/dqn_per.py | 4 ++-- ding/example/dqn_rnd.py | 4 ++-- ding/example/iqn_nstep.py | 4 ++-- ding/example/ppg_offpolicy.py | 4 ++-- ding/example/ppo.py | 4 ++-- ding/example/ppo_offpolicy.py | 4 ++-- ding/example/qrdqn_nstep.py | 4 ++-- ding/example/r2d2.py | 4 ++-- ding/example/sqil.py | 6 +++--- ding/example/sql.py | 4 ++-- ding/example/trex.py | 4 ++-- 15 files changed, 31 insertions(+), 31 deletions(-) diff --git a/ding/example/c51_nstep.py b/ding/example/c51_nstep.py index b88fe0d244..0c975957e1 100644 --- a/ding/example/c51_nstep.py +++ b/ding/example/c51_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn.py b/ding/example/dqn.py index 8e36834bbe..c3670def0a 100644 --- a/ding/example/dqn.py +++ b/ding/example/dqn.py @@ -57,11 +57,11 @@ def main(): ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn_new_env.py b/ding/example/dqn_new_env.py index 8579dfa242..97b6085aaf 100644 --- a/ding/example/dqn_new_env.py +++ b/ding/example/dqn_new_env.py @@ -20,12 +20,12 @@ def main(): with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = EnvSupervisor( type_=ChildType.THREAD, - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], **cfg.env.manager ) evaluator_env = EnvSupervisor( type_=ChildType.THREAD, - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], **cfg.env.manager ) diff --git a/ding/example/dqn_nstep.py b/ding/example/dqn_nstep.py index ef5d85d50d..896f2d03e5 100644 --- a/ding/example/dqn_nstep.py +++ 
b/ding/example/dqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn_per.py b/ding/example/dqn_per.py index caffa8a484..31196b1552 100644 --- a/ding/example/dqn_per.py +++ b/ding/example/dqn_per.py @@ -22,11 +22,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/dqn_rnd.py b/ding/example/dqn_rnd.py index af6db6b7cc..bbf605ed83 100644 --- a/ding/example/dqn_rnd.py +++ b/ding/example/dqn_rnd.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/iqn_nstep.py b/ding/example/iqn_nstep.py index 81a9d0446d..49e5a99f84 100644 --- a/ding/example/iqn_nstep.py +++ b/ding/example/iqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/ppg_offpolicy.py b/ding/example/ppg_offpolicy.py index 037f98348a..ab07efaa73 100644 --- a/ding/example/ppg_offpolicy.py +++ b/ding/example/ppg_offpolicy.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with 
task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/ppo.py b/ding/example/ppo.py index a63a57fce0..b2cd7aed01 100644 --- a/ding/example/ppo.py +++ b/ding/example/ppo.py @@ -19,11 +19,11 @@ def main(): ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/ppo_offpolicy.py b/ding/example/ppo_offpolicy.py index 7610928ea5..5f3b11c579 100644 --- a/ding/example/ppo_offpolicy.py +++ b/ding/example/ppo_offpolicy.py @@ -19,11 +19,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/qrdqn_nstep.py b/ding/example/qrdqn_nstep.py index db400c0fac..f6b6b5a95a 100644 --- a/ding/example/qrdqn_nstep.py +++ b/ding/example/qrdqn_nstep.py @@ -20,11 +20,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/r2d2.py b/ding/example/r2d2.py index b7e4c9d52a..f537ba990e 100644 --- a/ding/example/r2d2.py +++ b/ding/example/r2d2.py @@ -18,11 +18,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: 
DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/sqil.py b/ding/example/sqil.py index 3ac7dc0534..2443d9d6c5 100644 --- a/ding/example/sqil.py +++ b/ding/example/sqil.py @@ -24,15 +24,15 @@ def main(): expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) expert_collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/sql.py b/ding/example/sql.py index 999eef4733..a1c3a75680 100644 --- a/ding/example/sql.py +++ b/ding/example/sql.py @@ -18,11 +18,11 @@ def main(): cfg = compile_config(main_config, create_cfg=create_config, auto=True) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) diff --git a/ding/example/trex.py b/ding/example/trex.py index 9d2ffd9768..7c43f7fc79 100644 --- a/ding/example/trex.py +++ b/ding/example/trex.py @@ -28,12 +28,12 @@ def main(): with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = BaseEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v1")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) From 6a6798f37edc085533954777d4c1ca83406e1ceb Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 1 Dec 2022 16:08:40 +0800 Subject: [PATCH 68/70] fix(nyz): fix context manager env step merge bug(ci skip) --- ding/framework/middleware/distributer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index dc68d437f8..ff5c092be2 100644 --- a/ding/framework/middleware/distributer.py +++ 
b/ding/framework/middleware/distributer.py @@ -26,6 +26,9 @@ def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] if not task.router.is_active: raise RuntimeError("ContextHandler should be used in parallel mode!") self._state = {} + self._local_state = {} # just save local state, not send to remote node + if task.has_role(task.role.COLLECTOR): + self._local_state['env_step'] = 0 self._event_name = "context_exchanger_{role}" self._skip_n_iter = skip_n_iter self._storage_loader = storage_loader @@ -153,13 +156,15 @@ def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): if task.has_role(task.role.COLLECTOR): return trajectory_end_idx - def _put_env_step(self, env_step: int): + def _put_env_step(self, env_step_increment: int): if not task.has_role(task.role.COLLECTOR): - self._state["env_step"] = env_step + self._state["env_step"] += env_step_increment def _fetch_env_step(self, env_step: int): if task.has_role(task.role.COLLECTOR): - return env_step + env_step_increment = env_step - self._local_state['env_step'] + self._local_state['env_step'] = env_step + return env_step_increment def _put_env_episode(self, env_episode: int): if not task.has_role(task.role.COLLECTOR): From 060d1b8807f560ef9d4d2bedc326b5ba1da2c97a Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 1 Dec 2022 18:46:46 +0800 Subject: [PATCH 69/70] fix(nyz): fix context manager env step merge bug(ci skip) --- ding/framework/middleware/distributer.py | 29 +++++++++++++++++------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index ff5c092be2..a4481b1bc0 100644 --- a/ding/framework/middleware/distributer.py +++ b/ding/framework/middleware/distributer.py @@ -29,6 +29,7 @@ def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] self._local_state = {} # just save local state, not send to remote node if task.has_role(task.role.COLLECTOR): self._local_state['env_step'] = 0 + self._local_state['env_episode'] = 0 self._event_name = "context_exchanger_{role}" self._skip_n_iter = skip_n_iter self._storage_loader = storage_loader @@ -119,7 +120,13 @@ def merge(self, ctx: "Context"): sleep(0.01) for k, v in self._state.items(): - setattr(ctx, k, v) + if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): + pure_k = k.split('increment_')[-1] + setattr(ctx, pure_k, getattr(ctx, pure_k) + v) + if k == 'increment_env_step': + pass #print(task.has_role(task.role.LEARNER), v) + else: + setattr(ctx, k, v) self._state = {} # Handle each attibute of context @@ -156,23 +163,29 @@ def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): if task.has_role(task.role.COLLECTOR): return trajectory_end_idx - def _put_env_step(self, env_step_increment: int): + def _put_env_step(self, increment_env_step: int): if not task.has_role(task.role.COLLECTOR): - self._state["env_step"] += env_step_increment + if 'increment_env_step' not in self._state: + self._state['increment_env_step'] = 0 + self._state["increment_env_step"] += increment_env_step def _fetch_env_step(self, env_step: int): if task.has_role(task.role.COLLECTOR): - env_step_increment = env_step - self._local_state['env_step'] + increment_env_step = env_step - self._local_state['env_step'] self._local_state['env_step'] = env_step - return env_step_increment + return increment_env_step - def _put_env_episode(self, env_episode: int): + def _put_env_episode(self, increment_env_episode: int): if not 
task.has_role(task.role.COLLECTOR): - self._state["env_episode"] = env_episode + if 'increment_env_episode' not in self._state: + self._state['increment_env_episode'] = 0 + self._state["increment_env_episode"] += increment_env_episode def _fetch_env_episode(self, env_episode: int): if task.has_role(task.role.COLLECTOR): - return env_episode + increment_env_episode = env_episode - self._local_state['env_episode'] + self._local_state['env_episode'] = env_episode + return increment_env_episode def _put_train_iter(self, train_iter: int): if not task.has_role(task.role.LEARNER): From dde600961c98f690118ab8869f90c17c3f290490 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 11 Dec 2022 23:38:35 +0800 Subject: [PATCH 70/70] fix(nyz): fix flake8 style --- ding/framework/middleware/ckpt_handler.py | 2 +- ding/framework/middleware/distributer.py | 2 -- ding/framework/middleware/functional/evaluator.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 51097c9184..c23cf58b6d 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -67,4 +67,4 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: # finish if task.finish and self.save_finish: - save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) \ No newline at end of file + save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py index a4481b1bc0..c68a4b808f 100644 --- a/ding/framework/middleware/distributer.py +++ b/ding/framework/middleware/distributer.py @@ -123,8 +123,6 @@ def merge(self, ctx: "Context"): if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): pure_k = k.split('increment_')[-1] setattr(ctx, pure_k, getattr(ctx, pure_k) + v) - if k == 'increment_env_step': - pass #print(task.has_role(task.role.LEARNER), v) else: setattr(ctx, k, v) self._state = {} diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 25ce1fe076..f5553c4679 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -331,4 +331,4 @@ def _evaluate(ctx: "Context"): return _evaluate -# TODO battle evaluator \ No newline at end of file +# TODO battle evaluator
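Usage sketch (not part of any patch in this series): the two snippets below only illustrate the APIs introduced by the diffs above. MyEnv, learner_only_middleware and _run are hypothetical names made up for the example; Supervisor, ChildType, SendPayload, task.router, task.has_role, task.role and task.void are the names added or changed in this series.

The reworked Supervisor.register takes a zero-argument factory (or an already-constructed instance when a spawn context is used, since lambdas cannot be pickled under spawn), and shared memory is passed through shm_buffer/shm_callback instead of the removed SharedObject:

    from ding.framework.supervisor import Supervisor, ChildType, SendPayload

    class MyEnv:
        # hypothetical child object, mirroring MockEnv in the tests
        def step(self, action):
            return 1

    sv = Supervisor(type_=ChildType.PROCESS)  # an explicit mp_ctx (e.g. a spawn context) may also be passed
    for _ in range(3):
        sv.register(lambda: MyEnv())          # factory instead of (cls, *args)
    sv.start_link()
    sv.send(SendPayload(proc_id=0, method="step", args=["any action"]))
    payload = sv.recv()                       # RecvPayload whose data field holds the step result
    sv.shutdown()

A role-gated middleware follows the pattern used by online_logger in PATCH 63/70: when the current process does not hold the required role, return task.void() so that task.use() skips the middleware.

    from ding.framework import task

    def learner_only_middleware(cfg):
        # skipped on collector/evaluator processes in parallel mode
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()

        def _run(ctx):
            pass  # learner-side logic goes here

        return _run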