From 163be6923c06290290c7f9822c999e0909f7c256 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 4 Jan 2023 17:44:26 +0800 Subject: [PATCH 1/3] polish(nyz): polish example demos --- ding/config/config.py | 4 +++- ding/entry/tests/test_serial_entry.py | 8 ++++---- ding/example/c51_nstep.py | 2 +- ding/example/collect_demo_data.py | 2 +- ding/example/cql.py | 2 +- ding/example/d4pg.py | 2 +- ding/example/ddpg.py | 2 +- ding/example/dqn.py | 2 +- ding/example/dqn_her.py | 2 +- ding/example/dqn_new_env.py | 2 +- ding/example/dqn_nstep.py | 2 +- ding/example/dqn_per.py | 2 +- ding/example/dqn_rnd.py | 2 +- ding/example/iqn_nstep.py | 2 +- ding/example/pdqn.py | 2 +- ding/example/ppg_offpolicy.py | 8 ++++---- ding/example/ppo.py | 4 ++-- ding/example/ppo_offpolicy.py | 2 +- ding/example/qrdqn_nstep.py | 2 +- ding/example/r2d2.py | 2 +- ding/example/sac.py | 2 +- ding/example/sqil.py | 2 +- ding/example/sqil_continuous.py | 2 +- ding/example/sql.py | 2 +- ding/example/td3.py | 2 +- ding/example/trex.py | 6 +++--- ding/framework/middleware/ckpt_handler.py | 11 +++++----- ding/framework/middleware/collector.py | 2 +- ding/reward_model/trex_reward_model.py | 20 ++++++++++--------- .../cartpole/config/__init__.py | 2 +- .../cartpole/config/cartpole_c51_config.py | 1 + .../cartpole/config/cartpole_dqn_config.py | 1 + .../cartpole/config/cartpole_iqn_config.py | 1 + ...ig.py => cartpole_ppo_offpolicy_config.py} | 16 +++++++-------- .../cartpole/config/cartpole_sqil_config.py | 2 +- .../config/cartpole_trex_dqn_config.py | 7 ++----- .../config/pendulum_sqil_sac_config.py | 7 ++++--- 37 files changed, 74 insertions(+), 68 deletions(-) rename dizoo/classic_control/cartpole/config/{cartpole_offppo_config.py => cartpole_ppo_offpolicy_config.py} (74%) diff --git a/ding/config/config.py b/ding/config/config.py index 51a347de7c..574f3170e3 100644 --- a/ding/config/config.py +++ b/ding/config/config.py @@ -341,6 +341,7 @@ def compile_config( create_cfg: dict = None, save_cfg: bool = True, save_path: str = 'total_config.py', + renew_dir: bool = True, ) -> EasyDict: """ Overview: @@ -361,6 +362,7 @@ def compile_config( - create_cfg (:obj:`dict`): Input create config dict - save_cfg (:obj:`bool`): Save config or not - save_path (:obj:`str`): Path of saving file + - renew_dir (:obj:`bool`): Whether to new a directory for saving config. 
Returns: - cfg (:obj:`EasyDict`): Config after compiling """ @@ -460,7 +462,7 @@ def compile_config( if 'exp_name' not in cfg: cfg.exp_name = 'default_experiment' if save_cfg: - if os.path.exists(cfg.exp_name): + if os.path.exists(cfg.exp_name) and renew_dir: cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S") try: os.makedirs(cfg.exp_name) diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index 5d83f0557f..ddf72b7343 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -9,8 +9,8 @@ from dizoo.classic_control.cartpole.config.cartpole_dqn_stdim_config import cartpole_dqn_stdim_config, \ cartpole_dqn_stdim_create_config from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config -from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \ - cartpole_offppo_create_config +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \ + cartpole_ppo_offpolicy_create_config from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa @@ -209,7 +209,7 @@ def test_qrdqn(): @pytest.mark.platformtest @pytest.mark.unittest def test_ppo(): - config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] config[0].policy.learn.update_per_collect = 1 config[0].exp_name = 'ppo_offpolicy_unittest' try: @@ -221,7 +221,7 @@ def test_ppo(): @pytest.mark.platformtest @pytest.mark.unittest def test_ppo_nstep_return(): - config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] config[0].policy.learn.update_per_collect = 1 config[0].policy.nstep_return = True try: diff --git a/ding/example/c51_nstep.py b/ding/example/c51_nstep.py index 0c975957e1..2b98ece213 100644 --- a/ding/example/c51_nstep.py +++ b/ding/example/c51_nstep.py @@ -40,7 +40,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/collect_demo_data.py b/ding/example/collect_demo_data.py index da626934cc..76091dbabd 100644 --- a/ding/example/collect_demo_data.py +++ b/ding/example/collect_demo_data.py @@ -28,7 +28,7 @@ def main(): policy.collect_mode.load_state_dict(state_dict) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) - task.use(offline_data_saver(cfg, cfg.policy.collect.save_path, data_type='hdf5')) + task.use(offline_data_saver(cfg.policy.collect.save_path, data_type='hdf5')) task.run(max_step=1) diff --git a/ding/example/cql.py b/ding/example/cql.py index 1e1c678dd0..5af78dabd3 100644 --- a/ding/example/cql.py +++ b/ding/example/cql.py @@ -33,7 +33,7 @@ def main(): task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) 
task.use(offline_data_fetcher(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(offline_logger()) task.run() diff --git a/ding/example/d4pg.py b/ding/example/d4pg.py index e004c0bed0..39806f166d 100644 --- a/ding/example/d4pg.py +++ b/ding/example/d4pg.py @@ -40,7 +40,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/ddpg.py b/ding/example/ddpg.py index 9caa9b22f2..cf2fa4d8fb 100644 --- a/ding/example/ddpg.py +++ b/ding/example/ddpg.py @@ -37,7 +37,7 @@ def main(): ) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(termination_checker(max_train_iter=10000)) task.run() diff --git a/ding/example/dqn.py b/ding/example/dqn.py index c3670def0a..0959b3ab22 100644 --- a/ding/example/dqn.py +++ b/ding/example/dqn.py @@ -93,7 +93,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(online_logger(train_show_freq=10)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/dqn_her.py b/ding/example/dqn_her.py index d600938af0..b88458aa33 100644 --- a/ding/example/dqn_her.py +++ b/ding/example/dqn_her.py @@ -38,7 +38,7 @@ def main(): task.use(EpisodeCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_reward_model)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/dqn_new_env.py b/ding/example/dqn_new_env.py index 97b6085aaf..e43a9a8187 100644 --- a/ding/example/dqn_new_env.py +++ b/ding/example/dqn_new_env.py @@ -40,7 +40,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/dqn_nstep.py b/ding/example/dqn_nstep.py index 896f2d03e5..09dc786d22 100644 --- a/ding/example/dqn_nstep.py +++ b/ding/example/dqn_nstep.py @@ -40,7 +40,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(final_ctx_saver(cfg.exp_name)) task.run() diff --git a/ding/example/dqn_per.py b/ding/example/dqn_per.py index 31196b1552..fd6d736f8b 100644 --- a/ding/example/dqn_per.py +++ b/ding/example/dqn_per.py @@ -42,7 +42,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/dqn_rnd.py b/ding/example/dqn_rnd.py 
index bbf605ed83..2d5e1b93c3 100644 --- a/ding/example/dqn_rnd.py +++ b/ding/example/dqn_rnd.py @@ -40,7 +40,7 @@ def main(): task.use(trainer(cfg, reward_model)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, reward_model=reward_model)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/iqn_nstep.py b/ding/example/iqn_nstep.py index 49e5a99f84..eff6df85bf 100644 --- a/ding/example/iqn_nstep.py +++ b/ding/example/iqn_nstep.py @@ -40,7 +40,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/pdqn.py b/ding/example/pdqn.py index d470f140b7..5bc173d83c 100644 --- a/ding/example/pdqn.py +++ b/ding/example/pdqn.py @@ -37,7 +37,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) task.run() diff --git a/ding/example/ppg_offpolicy.py b/ding/example/ppg_offpolicy.py index ab07efaa73..70bd211cd5 100644 --- a/ding/example/ppg_offpolicy.py +++ b/ding/example/ppg_offpolicy.py @@ -1,7 +1,7 @@ import gym from ditk import logging from ding.model import PPG -from ding.policy import PPGPolicy +from ding.policy import PPGOffPolicy from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import DequeBuffer from ding.data.buffer.middleware import use_time_check, sample_range_view @@ -11,7 +11,7 @@ from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ CkptSaver, gae_estimator from ding.utils import set_pkg_seed -from dizoo.classic_control.cartpole.config.cartpole_ppg_config import main_config, create_config +from dizoo.classic_control.cartpole.config.cartpole_ppg_offpolicy_config import main_config, create_config def main(): @@ -39,13 +39,13 @@ def main(): value_buffer = buffer_.view() value_buffer.use(use_time_check(value_buffer, max_use=buffer_cfg.value.max_use)) value_buffer.use(sample_range_view(value_buffer, start=-buffer_cfg.value.replay_buffer_size)) - policy = PPGPolicy(cfg.policy, model=model) + policy = PPGOffPolicy(cfg.policy, model=model) task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(gae_estimator(cfg, policy.collect_mode, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, {'policy': policy_buffer, 'value': value_buffer})) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/ppo.py b/ding/example/ppo.py index b2cd7aed01..f7a76a26a2 100644 --- a/ding/example/ppo.py +++ b/ding/example/ppo.py @@ -35,8 +35,8 @@ def main(): task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(gae_estimator(cfg, policy.collect_mode)) - task.use(multistep_trainer(cfg, policy.learn_mode)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(multistep_trainer(policy.learn_mode, log_freq=50)) + task.use(CkptSaver(policy, cfg.exp_name, 
train_freq=100)) task.use(online_logger(train_show_freq=3)) task.run() diff --git a/ding/example/ppo_offpolicy.py b/ding/example/ppo_offpolicy.py index 5f3b11c579..738b27f230 100644 --- a/ding/example/ppo_offpolicy.py +++ b/ding/example/ppo_offpolicy.py @@ -37,7 +37,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(gae_estimator(cfg, policy.collect_mode, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/qrdqn_nstep.py b/ding/example/qrdqn_nstep.py index f6b6b5a95a..352828cf35 100644 --- a/ding/example/qrdqn_nstep.py +++ b/ding/example/qrdqn_nstep.py @@ -40,7 +40,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/r2d2.py b/ding/example/r2d2.py index f537ba990e..83fc617563 100644 --- a/ding/example/r2d2.py +++ b/ding/example/r2d2.py @@ -38,7 +38,7 @@ def main(): task.use(nstep_reward_enhancer(cfg)) task.use(data_pusher(cfg, buffer_, group_by_env=True)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/sac.py b/ding/example/sac.py index 8abb4ce1a5..9d154251a4 100644 --- a/ding/example/sac.py +++ b/ding/example/sac.py @@ -37,7 +37,7 @@ def main(): ) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(termination_checker(max_train_iter=10000)) task.use(online_logger()) task.run() diff --git a/ding/example/sqil.py b/ding/example/sqil.py index 2443d9d6c5..6df54a5724 100644 --- a/ding/example/sqil.py +++ b/ding/example/sqil.py @@ -57,7 +57,7 @@ def main(): task.use(StepCollector(cfg, expert_policy.collect_mode, expert_collector_env)) # expert data collector task.use(sqil_data_pusher(cfg, expert_buffer, expert=True)) task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer_, 0.5), (expert_buffer, 0.5)])) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/sqil_continuous.py b/ding/example/sqil_continuous.py index 930a5cbb27..4fd0a6f795 100644 --- a/ding/example/sqil_continuous.py +++ b/ding/example/sqil_continuous.py @@ -61,7 +61,7 @@ def main(): ) # expert data collector task.use(sqil_data_pusher(cfg, expert_buffer, expert=True)) task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer_, 0.5), (expert_buffer, 0.5)])) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.use(termination_checker(max_train_iter=10000)) task.run() diff --git a/ding/example/sql.py b/ding/example/sql.py index a1c3a75680..2c2a968082 100644 --- a/ding/example/sql.py +++ b/ding/example/sql.py @@ -37,7 +37,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() 
diff --git a/ding/example/td3.py b/ding/example/td3.py index 59fa0c59cc..a02089f3d9 100644 --- a/ding/example/td3.py +++ b/ding/example/td3.py @@ -36,7 +36,7 @@ def main(): ) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/example/trex.py b/ding/example/trex.py index 7c43f7fc79..97611ba6c2 100644 --- a/ding/example/trex.py +++ b/ding/example/trex.py @@ -15,16 +15,16 @@ from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \ eps_greedy_handler, CkptSaver, eps_greedy_masker, sqil_data_pusher, data_pusher from ding.utils import set_pkg_seed -from dizoo.classic_control.cartpole.config.cartpole_trex_dqn_config import main_config, create_config from ding.entry import trex_collecting_data from ding.reward_model import create_reward_model +from dizoo.classic_control.cartpole.config.cartpole_trex_dqn_config import main_config, create_config def main(): logging.getLogger().setLevel(logging.INFO) demo_arg = easydict.EasyDict({'cfg': [main_config, create_config], 'seed': 0}) trex_collecting_data(demo_arg) - cfg = compile_config(main_config, create_cfg=create_config, auto=True) + cfg = compile_config(main_config, create_cfg=create_config, auto=True, renew_dir=False) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = BaseEnvManagerV2( @@ -51,7 +51,7 @@ def main(): task.use(StepCollector(cfg, policy.collect_mode, collector_env)) task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, reward_model)) - task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) task.run() diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index abc72ff3e3..2d71818576 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -54,11 +54,12 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: - eval_value (:obj:`float`): The episode return of current iteration. 
""" # train enough iteration - if self.train_freq and ctx.train_iter - self.last_save_iter >= self.train_freq: - save_file( - "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict() - ) - self.last_save_iter = ctx.train_iter + if self.train_freq : + if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq: + save_file( + "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict() + ) + self.last_save_iter = ctx.train_iter # best episode return so far if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value: diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index f2bed3cbe6..5fd9d12288 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -53,7 +53,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: if self.random_collect_size > 0 and old < self.random_collect_size: target_size = self.random_collect_size - old random_policy = get_random_policy(self.cfg, self.policy, self.env) - current_inferencer = task.wrap(inferencer(self.cfg, random_policy, self.env)) + current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env)) else: # compatible with old config, a train sample = unroll_len step target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len diff --git a/ding/reward_model/trex_reward_model.py b/ding/reward_model/trex_reward_model.py index 9fc069626c..adac2b1589 100644 --- a/ding/reward_model/trex_reward_model.py +++ b/ding/reward_model/trex_reward_model.py @@ -1,9 +1,10 @@ from collections.abc import Iterable -from easydict import EasyDict -import numpy as np -import pickle from copy import deepcopy from typing import Tuple, Optional, List, Dict +from easydict import EasyDict +import pickle +import os +import numpy as np import torch import torch.nn as nn @@ -196,14 +197,14 @@ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> def load_expert_data(self) -> None: """ Overview: - Getting the expert data from ``config.data_path`` attribute in self + Getting the expert data. Effects: This is a side effect function which updates the expert data attribute \ (i.e. 
``self.expert_data``) with ``fn:concat_state_action_pairs`` """ - with open(self.cfg.reward_model.data_path + '/episodes_data.pkl', 'rb') as f: + with open(os.path.join(self.cfg.exp_name, 'episodes_data.pkl'), 'rb') as f: self.pre_expert_data = pickle.load(f) - with open(self.cfg.reward_model.data_path + '/learning_returns.pkl', 'rb') as f: + with open(os.path.join(self.cfg.exp_name, 'learning_returns.pkl'), 'rb') as f: self.learning_returns = pickle.load(f) self.create_training_data() @@ -316,12 +317,13 @@ def _train(self): item_loss = loss.item() cum_loss += item_loss if i % 100 == 99: - self._logger.info("epoch {}:{} loss {}".format(epoch, i, cum_loss)) + self._logger.info("[epoch {}:{}] loss {}".format(epoch, i, cum_loss)) self._logger.info("abs_returns: {}".format(abs_rewards)) cum_loss = 0.0 self._logger.info("check pointing") - torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path) - torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path) + if not os.path.exists(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')): + os.makedirs(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')) + torch.save(self.reward_model.state_dict(), os.path.join(self.cfg.exp_name, 'ckpt_reward_model/latest.pth.tar')) self._logger.info("finished training") def train(self): diff --git a/dizoo/classic_control/cartpole/config/__init__.py b/dizoo/classic_control/cartpole/config/__init__.py index 64fd50b9b2..6c5d7ebc06 100644 --- a/dizoo/classic_control/cartpole/config/__init__.py +++ b/dizoo/classic_control/cartpole/config/__init__.py @@ -7,7 +7,7 @@ from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config from .cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config from .cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config -from .cartpole_offppo_config import cartpole_offppo_config, cartpole_offppo_create_config +from .cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config from .cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config from .cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config from .cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config diff --git a/dizoo/classic_control/cartpole/config/cartpole_c51_config.py b/dizoo/classic_control/cartpole/config/cartpole_c51_config.py index dfb4733a1c..a9f8557292 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_c51_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_c51_config.py @@ -31,6 +31,7 @@ n_sample=80, unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=40, )), other=dict( eps=dict( type='exp', diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py index ca34d7d879..3e5ca613d0 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py @@ -21,6 +21,7 @@ nstep=1, discount_factor=0.97, learn=dict( + update_per_collect=5, batch_size=64, learning_rate=0.001, ), diff --git a/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py index d869005b65..d6fca73e93 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py @@ -29,6 +29,7 @@ n_sample=80, 
unroll_len=1, ), + eval=dict(evaluator=dict(eval_freq=40, )), other=dict( eps=dict( type='exp', diff --git a/dizoo/classic_control/cartpole/config/cartpole_offppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py similarity index 74% rename from dizoo/classic_control/cartpole/config/cartpole_offppo_config.py rename to dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py index a720fec55d..f952ecadaa 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_offppo_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict -cartpole_offppo_config = dict( - exp_name='cartpole_offppo_seed0', +cartpole_ppo_offpolicy_config = dict( + exp_name='cartpole_ppo_offpolicy_seed0', env=dict( collector_env_num=8, evaluator_env_num=5, @@ -37,9 +37,9 @@ other=dict(replay_buffer=dict(replay_buffer_size=5000)) ), ) -cartpole_offppo_config = EasyDict(cartpole_offppo_config) -main_config = cartpole_offppo_config -cartpole_offppo_create_config = dict( +cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config) +main_config = cartpole_ppo_offpolicy_config +cartpole_ppo_offpolicy_create_config = dict( env=dict( type='cartpole', import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], @@ -47,10 +47,10 @@ env_manager=dict(type='base'), policy=dict(type='ppo_offpolicy'), ) -cartpole_offppo_create_config = EasyDict(cartpole_offppo_create_config) -create_config = cartpole_offppo_create_config +cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config) +create_config = cartpole_ppo_offpolicy_create_config if __name__ == "__main__": - # or you can enter `ding -m serial -c cartpole_offppo_config.py -s 0` + # or you can enter `ding -m serial -c cartpole_ppo_offpolicy_config.py -s 0` from ding.entry import serial_pipeline serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py b/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py index ffff606c71..7e4553f729 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py @@ -24,7 +24,7 @@ # Users should add their own model path here. Model path should lead to a model. # Absolute path is recommended. # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``. 
- model_path='model_path_placeholder' + model_path='cartpole_dqn_seed0/ckpt/eval.pth.tar' ), # note: this is the times after which you learns to evaluate eval=dict(evaluator=dict(eval_freq=50, )), diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py index 724b9f64fa..306cadd6f2 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py @@ -7,7 +7,6 @@ evaluator_env_num=5, n_evaluator_episode=5, stop_value=195, - replay_path='cartpole_dqn/video', ), reward_model=dict( type='trex', @@ -20,12 +19,9 @@ update_per_collect=1, num_trajs=6, num_snippets=6000, - expert_model_path='abs model path', - reward_model_path='abs data path + ./cartpole.params', - data_path='abs data path', + expert_model_path='cartpole_dqn_seed0', # expert model experiment directory path ), policy=dict( - load_path='', cuda=False, model=dict( obs_shape=4, @@ -36,6 +32,7 @@ nstep=1, discount_factor=0.97, learn=dict( + update_per_collect=5, batch_size=64, learning_rate=0.001, ), diff --git a/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py b/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py index cecd4a057c..f80b85235f 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py @@ -36,12 +36,13 @@ target_theta=0.005, discount_factor=0.99, auto_alpha=True, - value_network=False, ), collect=dict( n_sample=10, - model_path='model_path_placeholder', - unroll_len=1, + # Users should add their own model path here. Model path should lead to a model. + # Absolute path is recommended. + # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``. 
+ model_path='pendulum_sac_seed0/ckpt/eval.pth.tar', ), eval=dict(evaluator=dict(eval_freq=100, )), other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ), From 6773d9762143a0e796ada99ab36f2c8e025f9524 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 4 Jan 2023 21:34:19 +0800 Subject: [PATCH 2/3] fix(nyz): fix unittest bugs --- ding/entry/tests/test_application_entry.py | 14 +++++++------- .../test_application_entry_trex_collect_data.py | 8 ++++---- ding/entry/tests/test_serial_entry_bc.py | 8 ++++---- .../test_serial_entry_preference_based_irl.py | 6 +++--- ding/entry/tests/test_serial_entry_reward_model.py | 6 +++--- ding/framework/middleware/ckpt_handler.py | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/ding/entry/tests/test_application_entry.py b/ding/entry/tests/test_application_entry.py index 65a8924954..2b4e01a5e2 100644 --- a/ding/entry/tests/test_application_entry.py +++ b/ding/entry/tests/test_application_entry.py @@ -3,8 +3,8 @@ import os import pickle -from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \ - cartpole_offppo_create_config # noqa +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \ + cartpole_ppo_offpolicy_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\ cartpole_trex_offppo_create_config from dizoo.classic_control.cartpole.envs import CartPoleEnv @@ -15,7 +15,7 @@ @pytest.fixture(scope='module') def setup_state_dict(): - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) try: policy = serial_pipeline(config, seed=0) except Exception: @@ -31,12 +31,12 @@ def setup_state_dict(): class TestApplication: def test_eval(self, setup_state_dict): - cfg_for_stop_value = compile_config(cartpole_offppo_config, auto=True, create_cfg=cartpole_offppo_create_config) + cfg_for_stop_value = compile_config(cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config) stop_value = cfg_for_stop_value.env.stop_value - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval']) assert episode_return >= stop_value - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) episode_return = eval( config, seed=0, @@ -46,7 +46,7 @@ def test_eval(self, setup_state_dict): assert episode_return >= stop_value def test_collect_demo_data(self, setup_state_dict): - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) collect_count = 16 expert_data_path = './expert.data' collect_demo_data( diff --git a/ding/entry/tests/test_application_entry_trex_collect_data.py b/ding/entry/tests/test_application_entry_trex_collect_data.py index 13e418b5f0..f5cb3c16b9 100644 --- a/ding/entry/tests/test_application_entry_trex_collect_data.py +++ b/ding/entry/tests/test_application_entry_trex_collect_data.py @@ -8,8 +8,8 @@ from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import 
cartpole_trex_offppo_config,\ cartpole_trex_offppo_create_config -from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\ - cartpole_offppo_create_config +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\ + cartpole_ppo_offpolicy_create_config from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data from ding.entry import serial_pipeline @@ -18,7 +18,7 @@ def test_collect_episodic_demo_data_for_trex(): exp_name = "test_collect_episodic_demo_data_for_trex_expert" expert_policy_state_dict_path = os.path.join(exp_name, 'expert_policy.pth.tar') - config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] config[0].exp_name = exp_name expert_policy = serial_pipeline(config, seed=0) torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) @@ -41,7 +41,7 @@ def test_collect_episodic_demo_data_for_trex(): @pytest.mark.unittest def test_trex_collecting_data(): expert_policy_dir = 'test_trex_collecting_data_expert' - config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] config[0].exp_name = expert_policy_dir config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 serial_pipeline(config, seed=0) diff --git a/ding/entry/tests/test_serial_entry_bc.py b/ding/entry/tests/test_serial_entry_bc.py index 4c2b1ae9cb..f2c0923ad2 100644 --- a/ding/entry/tests/test_serial_entry_bc.py +++ b/ding/entry/tests/test_serial_entry_bc.py @@ -14,7 +14,7 @@ from ding.utils import POLICY_REGISTRY from ding.utils.data import default_collate, default_decollate from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config, \ - cartpole_offppo_config, cartpole_offppo_create_config + cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config @@ -53,7 +53,7 @@ def _monitor_vars_learn(self) -> list: @pytest.mark.unittest def test_serial_pipeline_bc_ppo(): # train expert policy - train_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + train_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] train_config[0].exp_name = 'test_serial_pipeline_bc_ppo' expert_policy = serial_pipeline(train_config, seed=0) @@ -61,14 +61,14 @@ def test_serial_pipeline_bc_ppo(): collect_count = 10000 expert_data_path = 'expert_data_ppo_bc.pkl' state_dict = expert_policy.collect_mode.state_dict() - collect_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + collect_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] collect_config[0].exp_name = 'test_serial_pipeline_bc_ppo_collect' collect_demo_data( collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count ) # il training 1 - il_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] il_config[0].policy.eval.evaluator.multi_gpu = False il_config[0].policy.learn.train_epoch 
= 20 il_config[1].policy.type = 'ppo_bc' diff --git a/ding/entry/tests/test_serial_entry_preference_based_irl.py b/ding/entry/tests/test_serial_entry_preference_based_irl.py index 9718c01f0f..4ee679a6c0 100644 --- a/ding/entry/tests/test_serial_entry_preference_based_irl.py +++ b/ding/entry/tests/test_serial_entry_preference_based_irl.py @@ -9,8 +9,8 @@ from ding.entry import serial_pipeline_preference_based_irl from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\ cartpole_trex_offppo_create_config -from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\ - cartpole_offppo_create_config +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\ + cartpole_ppo_offpolicy_create_config from ding.entry.application_entry_trex_collect_data import trex_collecting_data from ding.reward_model.trex_reward_model import TrexConvEncoder from ding.torch_utils import is_differentiable @@ -19,7 +19,7 @@ @pytest.mark.unittest def test_serial_pipeline_trex(): exp_name = 'test_serial_pipeline_trex_expert' - config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)] + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 config[0].exp_name = exp_name expert_policy = serial_pipeline(config, seed=0) diff --git a/ding/entry/tests/test_serial_entry_reward_model.py b/ding/entry/tests/test_serial_entry_reward_model.py index a81c5a4b6e..404cb6d78c 100644 --- a/ding/entry/tests/test_serial_entry_reward_model.py +++ b/ding/entry/tests/test_serial_entry_reward_model.py @@ -5,7 +5,7 @@ from copy import deepcopy from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config -from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, cartpole_offppo_create_config # noqa +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \ @@ -44,13 +44,13 @@ @pytest.mark.parametrize('reward_model_config', cfg) def test_irl(reward_model_config): reward_model_config = EasyDict(reward_model_config) - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) expert_policy = serial_pipeline(config, seed=0, max_train_iter=2) # collect expert demo data collect_count = 10000 expert_data_path = 'expert_data.pkl' state_dict = expert_policy.collect_mode.state_dict() - config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config) + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) collect_demo_data( config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count ) diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index 
2d71818576..3502d9c50f 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -54,7 +54,7 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: - eval_value (:obj:`float`): The episode return of current iteration. """ # train enough iteration - if self.train_freq : + if self.train_freq: if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq: save_file( "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict() From 163ed01dda84fdb634540ed837559c8b01beb094 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 4 Jan 2023 22:43:49 +0800 Subject: [PATCH 3/3] fix(nyz): fix trex unittest bugs --- ding/entry/serial_entry_preference_based_irl.py | 2 +- ding/entry/serial_entry_preference_based_irl_onpolicy.py | 2 +- ding/entry/tests/test_application_entry.py | 4 +++- ding/entry/tests/test_serial_entry_preference_based_irl.py | 2 -- .../tests/test_serial_entry_preference_based_irl_onpolicy.py | 2 -- ding/framework/middleware/tests/test_ckpt_handler.py | 2 +- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/ding/entry/serial_entry_preference_based_irl.py b/ding/entry/serial_entry_preference_based_irl.py index 8ceaf52977..682e662baa 100644 --- a/ding/entry/serial_entry_preference_based_irl.py +++ b/ding/entry/serial_entry_preference_based_irl.py @@ -47,7 +47,7 @@ def serial_pipeline_preference_based_irl( create_cfg.policy.type = create_cfg.policy.type + '_command' create_cfg.reward_model = dict(type=cfg.reward_model.type) env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False) cfg_bak = copy.deepcopy(cfg) # Create main components: env, policy if env_setting is None: diff --git a/ding/entry/serial_entry_preference_based_irl_onpolicy.py b/ding/entry/serial_entry_preference_based_irl_onpolicy.py index 3bc1a8875b..3941f3337e 100644 --- a/ding/entry/serial_entry_preference_based_irl_onpolicy.py +++ b/ding/entry/serial_entry_preference_based_irl_onpolicy.py @@ -46,7 +46,7 @@ def serial_pipeline_preference_based_irl_onpolicy( create_cfg.policy.type = create_cfg.policy.type + '_command' create_cfg.reward_model = dict(type=cfg.reward_model.type) env_fn = None if env_setting is None else env_setting[0] - cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False) # Create main components: env, policy if env_setting is None: env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) diff --git a/ding/entry/tests/test_application_entry.py b/ding/entry/tests/test_application_entry.py index 2b4e01a5e2..9276d9e6e5 100644 --- a/ding/entry/tests/test_application_entry.py +++ b/ding/entry/tests/test_application_entry.py @@ -31,7 +31,9 @@ def setup_state_dict(): class TestApplication: def test_eval(self, setup_state_dict): - cfg_for_stop_value = compile_config(cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config) + cfg_for_stop_value = compile_config( + cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config + ) stop_value = cfg_for_stop_value.env.stop_value config = deepcopy(cartpole_ppo_offpolicy_config), 
deepcopy(cartpole_ppo_offpolicy_create_config) episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval']) diff --git a/ding/entry/tests/test_serial_entry_preference_based_irl.py b/ding/entry/tests/test_serial_entry_preference_based_irl.py index 4ee679a6c0..7e9198f929 100644 --- a/ding/entry/tests/test_serial_entry_preference_based_irl.py +++ b/ding/entry/tests/test_serial_entry_preference_based_irl.py @@ -27,8 +27,6 @@ def test_serial_pipeline_trex(): exp_name = 'test_serial_pipeline_trex_collect' config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)] config[0].exp_name = exp_name - config[0].reward_model.data_path = exp_name - config[0].reward_model.reward_model_path = exp_name + '/cartpole.params' config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert' config[0].reward_model.checkpoint_max = 100 config[0].reward_model.checkpoint_step = 100 diff --git a/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py b/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py index dfeee9b4e4..ff0e88b0d5 100644 --- a/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py +++ b/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py @@ -24,8 +24,6 @@ def test_serial_pipeline_trex_onpolicy(): exp_name = 'test_serial_pipeline_trex_onpolicy_collect' config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)] config[0].exp_name = exp_name - config[0].reward_model.data_path = exp_name - config[0].reward_model.reward_model_path = exp_name + '/cartpole.params' config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert' config[0].reward_model.checkpoint_max = 100 config[0].reward_model.checkpoint_step = 100 diff --git a/ding/framework/middleware/tests/test_ckpt_handler.py b/ding/framework/middleware/tests/test_ckpt_handler.py index 7c8c95c99b..f0d81f0a33 100644 --- a/ding/framework/middleware/tests/test_ckpt_handler.py +++ b/ding/framework/middleware/tests/test_ckpt_handler.py @@ -52,7 +52,7 @@ def mock_save_file(path, data, fs_type=None, use_lock=False): assert path == "{}/eval.pth.tar".format(prefix) with patch("ding.framework.middleware.ckpt_handler.save_file", mock_save_file): - ctx.train_iter = 0 + ctx.train_iter = 1 ctx.eval_value = 9.4 ckpt_saver = CkptSaver(policy, exp_name, train_freq) ckpt_saver(ctx)
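For reference, the usage pattern these three patches converge on can be seen in one place below: CkptSaver now takes the policy and the experiment directory name directly rather than the full config, and compile_config accepts renew_dir=False to keep reusing an existing experiment directory. This is a minimal sketch assembled from the cartpole DQN example hunks above, not a verbatim copy of any single file; the exact import layout and the eps_greedy_handler / online_logger middlewares are assumed from the surrounding diffs.

import gym
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \
    data_pusher, eps_greedy_handler, CkptSaver, online_logger
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    # renew_dir=False keeps writing into cfg.exp_name instead of creating a timestamped copy
    cfg = compile_config(main_config, create_cfg=create_config, auto=True, renew_dir=False)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = DQN(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = DQNPolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(online_logger(train_show_freq=10))
        # new CkptSaver signature: (policy, exp_name, train_freq=...) instead of (cfg, policy, train_freq=...)
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()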