Skip to content

Commit

Permalink
change test for mdqn from asterix to cartpole because of platform test failed
Browse files Browse the repository at this point in the history
  • Loading branch information
ruoyuGao committed Mar 5, 2023
1 parent b03b982 commit bdcd0ae
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
8 changes: 4 additions & 4 deletions ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa
from dizoo.atari.config.serial.asterix.asterix_mdqn_config import asterix_mdqn_config, asterix_mdqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config


@pytest.mark.platformtest
Expand All @@ -72,15 +72,15 @@ def test_dqn():
@pytest.mark.platformtest
@pytest.mark.unittest
def test_mdqn():
config = [deepcopy(asterix_mdqn_config), deepcopy(asterix_mdqn_create_config)]
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'asterix_mdqn_unittest'
config[0].exp_name = 'cartpole_mdqn_unittest'
try:
serial_pipeline(config, seed=0, max_train_iter=1, is_dynamic_seed=False)
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf asterix_mdqn_unittest')
os.popen('rm -rf cartpole_mdqn_unittest')


@pytest.mark.platformtest
Expand Down
3 changes: 1 addition & 2 deletions ding/policy/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import torch

from ding.torch_utils import Adam, to_device, ContrastiveLoss
from ding.rl_utils import q_nstep_td_data, m_q_1step_td_data,\
m_q_1step_td_error, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
Expand Down
1 change: 1 addition & 0 deletions dizoo/classic_control/cartpole/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config
from .cartpole_trex_offppo_config import cartpole_trex_offppo_config, cartpole_trex_offppo_create_config
from .cartpole_trex_onppo_config import cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config
from .cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config
# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
58 changes: 58 additions & 0 deletions dizoo/classic_control/cartpole/config/cartpole_mdqn_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from easydict import EasyDict

# Main experiment configuration, wrapped in an EasyDict.
# Grouped into an `env` section and a `policy` section.
cartpole_mdqn_config = EasyDict({
    'exp_name': 'cartpole_mdqn_seed0',
    'env': {
        'collector_env_num': 8,
        'evaluator_env_num': 5,
        'n_evaluator_episode': 5,
        'stop_value': 195,
    },
    'policy': {
        'cuda': False,
        'model': {
            'obs_shape': 4,
            'action_shape': 2,
            'encoder_hidden_size_list': [128, 128, 64],
            'dueling': True,
        },
        'nstep': 1,
        'discount_factor': 0.97,
        # NOTE(review): entropy_tau / m_alpha are presumably the
        # MDQN-specific coefficients -- confirm against the mdqn policy.
        'entropy_tau': 0.03,
        'm_alpha': 0.9,
        'learn': {
            'update_per_collect': 5,
            'batch_size': 64,
            'learning_rate': 0.001,
        },
        'collect': {'n_sample': 8},
        'eval': {'evaluator': {'eval_freq': 40}},
        'other': {
            # Epsilon schedule settings: type 'exp', start 0.95, end 0.1.
            'eps': {
                'type': 'exp',
                'start': 0.95,
                'end': 0.1,
                'decay': 10000,
            },
            'replay_buffer': {'replay_buffer_size': 20000},
        },
    },
})
# Generic alias used by the `__main__` entry at the bottom of this file.
main_config = cartpole_mdqn_config
# Creation configuration, wrapped in an EasyDict: the `type` strings (and
# import paths) for the env, env manager, policy and replay buffer.
cartpole_mdqn_create_config = EasyDict({
    'env': {
        'type': 'cartpole',
        'import_names': ['dizoo.classic_control.cartpole.envs.cartpole_env'],
    },
    'env_manager': {'type': 'base'},
    'policy': {'type': 'mdqn'},
    'replay_buffer': {
        'type': 'deque',
        'import_names': ['ding.data.buffer.deque_buffer_wrapper'],
    },
})
# Generic alias used by the `__main__` entry at the bottom of this file.
create_config = cartpole_mdqn_create_config

if __name__ == "__main__":
    # Script entry: run the serial pipeline with the two configs defined
    # above, using a fixed seed (seed=0, is_dynamic_seed=False).
    # or you can enter `ding -m serial -c cartpole_mdqn_config.py -s 0`
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), seed=0, is_dynamic_seed=False)

0 comments on commit bdcd0ae

Please sign in to comment.