-
Notifications
You must be signed in to change notification settings - Fork 381
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(pu): add ddp config of dqn and onppo (#842)
* feature(pu): add pong and cartpole ddp config of dqn and onppo * fix(pu):fix atari_ppo_ddp.py * polish(pu): polish atari_dqn_ddp.py and atari_ppo_ddp.py * polish(pu): polish atari ddp configs
- Loading branch information
1 parent
580ea65
commit 9a6e46f
Showing
11 changed files
with
304 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from easydict import EasyDict | ||
|
||
pong_dqn_config = dict( | ||
exp_name='data_pong/pong_dqn_ddp_seed0', | ||
env=dict( | ||
collector_env_num=4, | ||
evaluator_env_num=4, | ||
n_evaluator_episode=8, | ||
stop_value=20, | ||
env_id='PongNoFrameskip-v4', | ||
#'ALE/Pong-v5' is available. But special setting is needed after gym make. | ||
frame_stack=4, | ||
), | ||
policy=dict( | ||
multi_gpu=True, | ||
cuda=True, | ||
priority=False, | ||
model=dict( | ||
obs_shape=[4, 84, 84], | ||
action_shape=6, | ||
encoder_hidden_size_list=[128, 128, 512], | ||
), | ||
nstep=3, | ||
discount_factor=0.99, | ||
learn=dict( | ||
update_per_collect=10, | ||
batch_size=32, | ||
learning_rate=0.0001, | ||
target_update_freq=500, | ||
), | ||
collect=dict(n_sample=96, ), | ||
eval=dict(evaluator=dict(eval_freq=4000, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=1., | ||
end=0.05, | ||
decay=250000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=100000, ), | ||
), | ||
), | ||
) | ||
pong_dqn_config = EasyDict(pong_dqn_config) | ||
main_config = pong_dqn_config | ||
pong_dqn_create_config = dict( | ||
env=dict( | ||
type='atari', | ||
import_names=['dizoo.atari.envs.atari_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='dqn'), | ||
) | ||
pong_dqn_create_config = EasyDict(pong_dqn_create_config) | ||
create_config = pong_dqn_create_config | ||
|
||
if __name__ == '__main__': | ||
""" | ||
Overview: | ||
This script should be executed with <nproc_per_node> GPUs. | ||
Run the following command to launch the script: | ||
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py | ||
""" | ||
from ding.utils import DDPContext | ||
from ding.entry import serial_pipeline | ||
with DDPContext(): | ||
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e6)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from easydict import EasyDict | ||
|
||
pong_ppo_config = dict( | ||
exp_name='data_pong/pong_ppo_ddp_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=8, | ||
n_evaluator_episode=8, | ||
stop_value=20, | ||
env_id='PongNoFrameskip-v4', | ||
#'ALE/Pong-v5' is available. But special setting is needed after gym make. | ||
frame_stack=4, | ||
), | ||
policy=dict( | ||
multi_gpu=True, | ||
cuda=True, | ||
recompute_adv=True, | ||
action_space='discrete', | ||
model=dict( | ||
obs_shape=[4, 84, 84], | ||
action_shape=6, | ||
action_space='discrete', | ||
encoder_hidden_size_list=[64, 64, 128], | ||
actor_head_hidden_size=128, | ||
critic_head_hidden_size=128, | ||
), | ||
learn=dict( | ||
epoch_per_collect=10, | ||
update_per_collect=1, | ||
batch_size=320, | ||
learning_rate=3e-4, | ||
value_weight=0.5, | ||
entropy_weight=0.001, | ||
clip_ratio=0.2, | ||
adv_norm=True, | ||
value_norm=True, | ||
# for ppo, when we recompute adv, we need the key done in data to split traj, so we must | ||
# use ignore_done=False here, | ||
# but when we add key traj_flag in data as the backup for key done, we could choose to use ignore_done=True | ||
# for halfcheetah, the length=1000 | ||
ignore_done=False, | ||
grad_clip_type='clip_norm', | ||
grad_clip_value=0.5, | ||
), | ||
collect=dict( | ||
n_sample=3200, | ||
unroll_len=1, | ||
discount_factor=0.99, | ||
gae_lambda=0.95, | ||
), | ||
eval=dict(evaluator=dict(eval_freq=1000, )), | ||
), | ||
) | ||
main_config = EasyDict(pong_ppo_config) | ||
|
||
pong_ppo_create_config = dict( | ||
env=dict( | ||
type='atari', | ||
import_names=['dizoo.atari.envs.atari_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='ppo'), | ||
) | ||
create_config = EasyDict(pong_ppo_create_config) | ||
|
||
if __name__ == "__main__": | ||
""" | ||
Overview: | ||
This script should be executed with <nproc_per_node> GPUs. | ||
Run the following command to launch the script: | ||
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py | ||
""" | ||
from ding.utils import DDPContext | ||
from ding.entry import serial_pipeline_onpolicy | ||
with DDPContext(): | ||
serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from easydict import EasyDict | ||
|
||
cartpole_dqn_config = dict( | ||
exp_name='cartpole_dqn_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=5, | ||
n_evaluator_episode=5, | ||
stop_value=195, | ||
replay_path='cartpole_dqn_seed0/video', | ||
), | ||
policy=dict( | ||
multi_gpu=True, | ||
cuda=True, | ||
model=dict( | ||
obs_shape=4, | ||
action_shape=2, | ||
encoder_hidden_size_list=[128, 128, 64], | ||
dueling=True, | ||
# dropout=0.1, | ||
), | ||
nstep=1, | ||
discount_factor=0.97, | ||
learn=dict( | ||
update_per_collect=5, | ||
batch_size=64, | ||
learning_rate=0.001, | ||
), | ||
collect=dict(n_sample=8), | ||
eval=dict(evaluator=dict(eval_freq=40, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=0.95, | ||
end=0.1, | ||
decay=10000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=20000, ), | ||
), | ||
), | ||
) | ||
cartpole_dqn_config = EasyDict(cartpole_dqn_config) | ||
main_config = cartpole_dqn_config | ||
cartpole_dqn_create_config = dict( | ||
env=dict( | ||
type='cartpole', | ||
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='dqn'), | ||
) | ||
cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config) | ||
create_config = cartpole_dqn_create_config | ||
|
||
if __name__ == "__main__": | ||
""" | ||
Overview: | ||
This script should be executed with <nproc_per_node> GPUs. | ||
Run the following command to launch the script: | ||
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py | ||
""" | ||
from ding.utils import DDPContext | ||
from ding.entry import serial_pipeline | ||
with DDPContext(): | ||
serial_pipeline((main_config, create_config), seed=0) | ||
|
Oops, something went wrong.