polish(nyz): polish example demos #568

Merged · 3 commits · Jan 4, 2023
4 changes: 3 additions & 1 deletion ding/config/config.py
@@ -341,6 +341,7 @@ def compile_config(
create_cfg: dict = None,
save_cfg: bool = True,
save_path: str = 'total_config.py',
+ renew_dir: bool = True,
) -> EasyDict:
"""
Overview:
@@ -361,6 +362,7 @@ def compile_config(
- create_cfg (:obj:`dict`): Input create config dict
- save_cfg (:obj:`bool`): Save config or not
- save_path (:obj:`str`): Path of saving file
+ - renew_dir (:obj:`bool`): Whether to create a new directory for saving config.
Returns:
- cfg (:obj:`EasyDict`): Config after compiling
"""
@@ -460,7 +462,7 @@ def compile_config(
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
if save_cfg:
- if os.path.exists(cfg.exp_name):
+ if os.path.exists(cfg.exp_name) and renew_dir:
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
try:
os.makedirs(cfg.exp_name)
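For context, a minimal standalone sketch of what the new flag controls (stdlib only; `resolve_exp_dir` is a hypothetical helper, not DI-engine code): with `renew_dir=True` an existing experiment directory is preserved and the new config goes into a timestamped sibling, while `renew_dir=False` reuses the directory in place.

```python
import datetime
import os


def resolve_exp_dir(exp_name: str, renew_dir: bool = True) -> str:
    """Hypothetical helper mirroring the renew_dir branch of compile_config."""
    if os.path.exists(exp_name) and renew_dir:
        # Keep the previous run intact; write this run into a timestamped sibling.
        exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
    os.makedirs(exp_name, exist_ok=True)
    return exp_name
```

The two IRL entries below pass `renew_dir=False`, presumably so a pipeline whose experiment directory already exists (e.g. populated by an earlier data-collection stage) keeps writing into it rather than forking a new one.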
2 changes: 1 addition & 1 deletion ding/entry/serial_entry_preference_based_irl.py
@@ -47,7 +47,7 @@ def serial_pipeline_preference_based_irl(
create_cfg.policy.type = create_cfg.policy.type + '_command'
create_cfg.reward_model = dict(type=cfg.reward_model.type)
env_fn = None if env_setting is None else env_setting[0]
- cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
cfg_bak = copy.deepcopy(cfg)
# Create main components: env, policy
if env_setting is None:
2 changes: 1 addition & 1 deletion ding/entry/serial_entry_preference_based_irl_onpolicy.py
@@ -46,7 +46,7 @@ def serial_pipeline_preference_based_irl_onpolicy(
create_cfg.policy.type = create_cfg.policy.type + '_command'
create_cfg.reward_model = dict(type=cfg.reward_model.type)
env_fn = None if env_setting is None else env_setting[0]
- cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
16 changes: 9 additions & 7 deletions ding/entry/tests/test_application_entry.py
@@ -3,8 +3,8 @@
import os
import pickle

- from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \
-     cartpole_offppo_create_config  # noqa
+ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
+     cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
@@ -15,7 +15,7 @@

@pytest.fixture(scope='module')
def setup_state_dict():
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
try:
policy = serial_pipeline(config, seed=0)
except Exception:
@@ -31,12 +31,14 @@ def setup_state_dict():
class TestApplication:

def test_eval(self, setup_state_dict):
- cfg_for_stop_value = compile_config(cartpole_offppo_config, auto=True, create_cfg=cartpole_offppo_create_config)
+ cfg_for_stop_value = compile_config(
+     cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
+ )
stop_value = cfg_for_stop_value.env.stop_value
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
assert episode_return >= stop_value
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
episode_return = eval(
config,
seed=0,
@@ -46,7 +48,7 @@ def test_eval(self, setup_state_dict):
assert episode_return >= stop_value

def test_collect_demo_data(self, setup_state_dict):
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
collect_count = 16
expert_data_path = './expert.data'
collect_demo_data(
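The remaining test edits in this PR are one mechanical rename of the CartPole off-policy PPO config module. Before/after of the import, with module paths exactly as they appear in this diff:

```python
# Before this PR:
# from dizoo.classic_control.cartpole.config.cartpole_offppo_config import \
#     cartpole_offppo_config, cartpole_offppo_create_config

# After this PR:
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import \
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
```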
8 changes: 4 additions & 4 deletions ding/entry/tests/test_application_entry_trex_collect_data.py
@@ -8,8 +8,8 @@

from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
cartpole_trex_offppo_create_config
- from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\
-     cartpole_offppo_create_config
+ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
+     cartpole_ppo_offpolicy_create_config
from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data
from ding.entry import serial_pipeline

@@ -18,7 +18,7 @@
def test_collect_episodic_demo_data_for_trex():
exp_name = "test_collect_episodic_demo_data_for_trex_expert"
expert_policy_state_dict_path = os.path.join(exp_name, 'expert_policy.pth.tar')
- config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].exp_name = exp_name
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
@@ -41,7 +41,7 @@ def test_collect_episodic_demo_data_for_trex():
@pytest.mark.unittest
def test_trex_collecting_data():
expert_policy_dir = 'test_trex_collecting_data_expert'
- config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].exp_name = expert_policy_dir
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
serial_pipeline(config, seed=0)
8 changes: 4 additions & 4 deletions ding/entry/tests/test_serial_entry.py
@@ -9,8 +9,8 @@
from dizoo.classic_control.cartpole.config.cartpole_dqn_stdim_config import cartpole_dqn_stdim_config, \
cartpole_dqn_stdim_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
- from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \
-     cartpole_offppo_create_config
+ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
+     cartpole_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa
@@ -209,7 +209,7 @@ def test_qrdqn():
@pytest.mark.platformtest
@pytest.mark.unittest
def test_ppo():
- config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'ppo_offpolicy_unittest'
try:
@@ -221,7 +221,7 @@ def test_ppo():
@pytest.mark.platformtest
@pytest.mark.unittest
def test_ppo_nstep_return():
- config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].policy.nstep_return = True
try:
8 changes: 4 additions & 4 deletions ding/entry/tests/test_serial_entry_bc.py
@@ -14,7 +14,7 @@
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config, \
-     cartpole_offppo_config, cartpole_offppo_create_config
+     cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config


@@ -53,22 +53,22 @@ def _monitor_vars_learn(self) -> list:
@pytest.mark.unittest
def test_serial_pipeline_bc_ppo():
# train expert policy
- train_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ train_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
train_config[0].exp_name = 'test_serial_pipeline_bc_ppo'
expert_policy = serial_pipeline(train_config, seed=0)

# collect expert demo data
collect_count = 10000
expert_data_path = 'expert_data_ppo_bc.pkl'
state_dict = expert_policy.collect_mode.state_dict()
- collect_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ collect_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
collect_config[0].exp_name = 'test_serial_pipeline_bc_ppo_collect'
collect_demo_data(
collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)

# il training 1
- il_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
il_config[0].policy.eval.evaluator.multi_gpu = False
il_config[0].policy.learn.train_epoch = 20
il_config[1].policy.type = 'ppo_bc'
8 changes: 3 additions & 5 deletions ding/entry/tests/test_serial_entry_preference_based_irl.py
@@ -9,8 +9,8 @@
from ding.entry import serial_pipeline_preference_based_irl
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
cartpole_trex_offppo_create_config
- from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\
-     cartpole_offppo_create_config
+ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
+     cartpole_ppo_offpolicy_create_config
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
from ding.reward_model.trex_reward_model import TrexConvEncoder
from ding.torch_utils import is_differentiable
@@ -19,16 +19,14 @@
@pytest.mark.unittest
def test_serial_pipeline_trex():
exp_name = 'test_serial_pipeline_trex_expert'
- config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
config[0].exp_name = exp_name
expert_policy = serial_pipeline(config, seed=0)

exp_name = 'test_serial_pipeline_trex_collect'
config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
config[0].exp_name = exp_name
- config[0].reward_model.data_path = exp_name
- config[0].reward_model.reward_model_path = exp_name + '/cartpole.params'
config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert'
config[0].reward_model.checkpoint_max = 100
config[0].reward_model.checkpoint_step = 100
@@ -24,8 +24,6 @@ def test_serial_pipeline_trex_onpolicy():
exp_name = 'test_serial_pipeline_trex_onpolicy_collect'
config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
config[0].exp_name = exp_name
- config[0].reward_model.data_path = exp_name
- config[0].reward_model.reward_model_path = exp_name + '/cartpole.params'
config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert'
config[0].reward_model.checkpoint_max = 100
config[0].reward_model.checkpoint_step = 100
6 changes: 3 additions & 3 deletions ding/entry/tests/test_serial_entry_reward_model.py
@@ -5,7 +5,7 @@
from copy import deepcopy

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
- from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, cartpole_offppo_create_config  # noqa
+ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
@@ -44,13 +44,13 @@
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
reward_model_config = EasyDict(reward_model_config)
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
# collect expert demo data
collect_count = 10000
expert_data_path = 'expert_data.pkl'
state_dict = expert_policy.collect_mode.state_dict()
- config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
collect_demo_data(
config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)
2 changes: 1 addition & 1 deletion ding/example/c51_nstep.py
@@ -40,7 +40,7 @@ def main():
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


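All the example scripts below make the same middleware change: `CkptSaver(cfg, policy, train_freq=...)` becomes `CkptSaver(policy, cfg.exp_name, train_freq=...)`, i.e. the saver now takes the policy plus a save directory instead of the whole config. A runnable stand-in sketch of the new calling convention (this stub class is illustrative only, not DI-engine's implementation):

```python
import os

import torch


class CkptSaver:
    """Illustrative stub mirroring the new (policy, save_dir) signature."""

    def __init__(self, policy, save_dir: str, train_freq: int = 100):
        self.policy = policy
        self.ckpt_dir = os.path.join(save_dir, 'ckpt')
        self.train_freq = train_freq
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def __call__(self, train_iter: int) -> None:
        # Persist the learn-mode weights every `train_freq` training iterations.
        if train_iter % self.train_freq == 0:
            path = os.path.join(self.ckpt_dir, 'iteration_{}.pth.tar'.format(train_iter))
            torch.save(self.policy.learn_mode.state_dict(), path)
```

Passing `cfg.exp_name` explicitly decouples the middleware from the full config object, which is a plausible motive for the polish; the PR itself does not state the rationale.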
2 changes: 1 addition & 1 deletion ding/example/collect_demo_data.py
@@ -28,7 +28,7 @@ def main():
policy.collect_mode.load_state_dict(state_dict)

task.use(StepCollector(cfg, policy.collect_mode, collector_env))
- task.use(offline_data_saver(cfg, cfg.policy.collect.save_path, data_type='hdf5'))
+ task.use(offline_data_saver(cfg.policy.collect.save_path, data_type='hdf5'))
task.run(max_step=1)


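`offline_data_saver` gets the same treatment: the leading `cfg` argument is dropped and the save path is passed directly. A self-contained sketch of that middleware shape (hypothetical body; the real middleware serializes collected data via DI-engine utilities):

```python
import pickle


def offline_data_saver(save_path: str, data_type: str = 'hdf5'):
    """Illustrative stub mirroring the new (save_path, data_type) signature."""

    def _save(ctx) -> None:
        # The real middleware reads collected trajectories from the task
        # context; pickling here just illustrates the call shape.
        with open(save_path, 'wb') as f:
            pickle.dump(getattr(ctx, 'trajectories', []), f)

    return _save
```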
2 changes: 1 addition & 1 deletion ding/example/cql.py
@@ -33,7 +33,7 @@ def main():
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(offline_logger())
task.run()

2 changes: 1 addition & 1 deletion ding/example/d4pg.py
@@ -40,7 +40,7 @@ def main():
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/ddpg.py
@@ -37,7 +37,7 @@ def main():
)
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(termination_checker(max_train_iter=10000))
task.run()

2 changes: 1 addition & 1 deletion ding/example/dqn.py
@@ -93,7 +93,7 @@ def main():
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(online_logger(train_show_freq=10))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))

task.run()

2 changes: 1 addition & 1 deletion ding/example/dqn_her.py
@@ -38,7 +38,7 @@ def main():
task.use(EpisodeCollector(cfg, policy.collect_mode, collector_env))
task.use(data_pusher(cfg, buffer_))
task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_reward_model))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/dqn_new_env.py
@@ -40,7 +40,7 @@ def main():
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/dqn_nstep.py
@@ -40,7 +40,7 @@ def main():
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(final_ctx_saver(cfg.exp_name))
task.run()

2 changes: 1 addition & 1 deletion ding/example/dqn_per.py
@@ -42,7 +42,7 @@ def main():
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/dqn_rnd.py
@@ -40,7 +40,7 @@ def main():
task.use(trainer(cfg, reward_model))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, reward_model=reward_model))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/iqn_nstep.py
@@ -40,7 +40,7 @@ def main():
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


2 changes: 1 addition & 1 deletion ding/example/pdqn.py
@@ -37,7 +37,7 @@ def main():
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
- task.use(CkptSaver(cfg, policy, train_freq=1000))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.run()


8 changes: 4 additions & 4 deletions ding/example/ppg_offpolicy.py
@@ -1,7 +1,7 @@
import gym
from ditk import logging
from ding.model import PPG
- from ding.policy import PPGPolicy
+ from ding.policy import PPGOffPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.data.buffer.middleware import use_time_check, sample_range_view
@@ -11,7 +11,7 @@
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
CkptSaver, gae_estimator
from ding.utils import set_pkg_seed
- from dizoo.classic_control.cartpole.config.cartpole_ppg_config import main_config, create_config
+ from dizoo.classic_control.cartpole.config.cartpole_ppg_offpolicy_config import main_config, create_config


def main():
@@ -39,13 +39,13 @@ def main():
value_buffer = buffer_.view()
value_buffer.use(use_time_check(value_buffer, max_use=buffer_cfg.value.max_use))
value_buffer.use(sample_range_view(value_buffer, start=-buffer_cfg.value.replay_buffer_size))
- policy = PPGPolicy(cfg.policy, model=model)
+ policy = PPGOffPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, {'policy': policy_buffer, 'value': value_buffer}))
- task.use(CkptSaver(cfg, policy, train_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


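The PPG example pairs two renames: the policy class (`PPGPolicy` → `PPGOffPolicy`) and its config module (`cartpole_ppg_config` → `cartpole_ppg_offpolicy_config`), so both names now state explicitly that this example is the off-policy variant. The updated imports, exactly as they appear in this diff:

```python
from ding.policy import PPGOffPolicy
from dizoo.classic_control.cartpole.config.cartpole_ppg_offpolicy_config import main_config, create_config
```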