feature(zym): update ppo config to support discrete action space (#809)
YinminZhang authored Jul 1, 2024
1 parent 35ec39e commit 7f95159
Showing 3 changed files with 54 additions and 43 deletions.
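Across all three files the change follows the same pattern: fewer collector environments, a 512-wide encoder output feeding zero-layer actor and critic heads, a unified learn schedule with a linear lr_scheduler, and n_sample=1024 per collection. Because the rendered diffs below interleave removed and added lines, the following reading aid consolidates what the new (added) side of the policy section looks like; the values are taken from the diffs, the surrounding dict layout mirrors the config files, and it is a sketch rather than a verbatim copy of any one file.

from easydict import EasyDict

# Sketch of the discrete-action on-policy PPO settings introduced by this commit.
# Values come from the added side of the diffs below; nothing else is authoritative.
example_onppo_policy = EasyDict(dict(
    cuda=True,
    action_space='discrete',
    model=dict(
        obs_shape=[4, 84, 84],                       # stacked 84x84 frames
        action_shape=9,                              # Enduro; 6 for Qbert / SpaceInvaders
        action_space='discrete',
        encoder_hidden_size_list=[32, 64, 64, 512],
        actor_head_layer_num=0,                      # heads read the 512-dim encoder output directly
        critic_head_layer_num=0,
        actor_head_hidden_size=512,
        critic_head_hidden_size=512,
    ),
    learn=dict(
        lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
        epoch_per_collect=4,
        batch_size=256,
        learning_rate=2.5e-4,
        value_weight=0.5,
        entropy_weight=0.01,
        clip_ratio=0.1,
        adv_norm=True,
        value_norm=True,
        ignore_done=False,
        grad_clip_type='clip_norm',
        grad_clip_value=0.5,
    ),
    collect=dict(n_sample=1024, unroll_len=1, discount_factor=0.99, gae_lambda=0.95),
))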
47 changes: 27 additions & 20 deletions dizoo/atari/config/serial/enduro/enduro_onppo_config.py
@@ -3,7 +3,7 @@
enduro_onppo_config = dict(
exp_name='enduro_onppo_seed0',
env=dict(
collector_env_num=64,
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
@@ -14,38 +14,45 @@
),
policy=dict(
cuda=True,
recompute_adv=True,
action_space='discrete',
model=dict(
obs_shape=[4, 84, 84],
action_shape=9,
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
action_space='discrete',
encoder_hidden_size_list=[32, 64, 64, 512],
actor_head_layer_num=0,
critic_head_layer_num=0,
actor_head_hidden_size=512,
critic_head_hidden_size=512,
),
learn=dict(
update_per_collect=24,
batch_size=128,
# (bool) Whether to normalize advantage. Default to False.
adv_norm=False,
learning_rate=0.0001,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=1.0,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.001, # [0.1, 0.01 ,0.0]
clip_ratio=0.1
lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
epoch_per_collect=4,
batch_size=256,
learning_rate=2.5e-4,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.1,
adv_norm=True,
value_norm=True,
# For on-policy PPO, recomputing advantages requires the `done` key in the data to split
# trajectories, so we must use ignore_done=False here.
# If a `traj_flag` key is added to the data as a backup for `done`, ignore_done=True can be used instead.
# For HalfCheetah, the episode length is 1000.
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
),
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=1024,
unroll_len=1,
# (float) the trade-off factor lambda to balance 1step td and mc
gae_lambda=0.95,
discount_factor=0.99,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(replay_buffer=dict(
replay_buffer_size=10000,
max_use=3,
), ),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
main_config = EasyDict(enduro_onppo_config)
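For context, configs like enduro_onppo_config.py are normally launched through DI-engine's on-policy serial pipeline, with a companion create_config defined further down in the same file (below the region shown above). The snippet below is a hedged sketch: the serial_pipeline_onpolicy entry point and the create_config export are assumptions based on the usual layout of these config files, not something this diff shows.

# Hypothetical launch script; not part of the commit.
from ding.entry import serial_pipeline_onpolicy  # assumed DI-engine entry point for on-policy PPO
# main_config is the EasyDict built at the bottom of the file above; create_config is assumed to be
# defined in the collapsed part of the same file, following the usual DI-engine config pattern.
from dizoo.atari.config.serial.enduro.enduro_onppo_config import main_config, create_config

if __name__ == '__main__':
    serial_pipeline_onpolicy((main_config, create_config), seed=0)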
26 changes: 14 additions & 12 deletions dizoo/atari/config/serial/qbert/qbert_onppo_config.py
@@ -1,9 +1,9 @@
from easydict import EasyDict

qbert_onppo_config = dict(
exp_name='enduro_onppo_seed0',
exp_name='qbert_onppo_seed0',
env=dict(
collector_env_num=16,
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=int(1e10),
@@ -19,18 +19,20 @@
obs_shape=[4, 84, 84],
action_shape=6,
action_space='discrete',
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
encoder_hidden_size_list=[32, 64, 64, 512],
actor_head_layer_num=0,
critic_head_layer_num=0,
actor_head_hidden_size=512,
critic_head_hidden_size=512,
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
epoch_per_collect=4,
batch_size=256,
learning_rate=2.5e-4,
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
entropy_weight=0.01,
clip_ratio=0.1,
adv_norm=True,
value_norm=True,
# for onppo, when we recompute adv, we need the key done in data to split traj, so we must
@@ -42,7 +44,7 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=1024,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
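The learn/collect values shared by the new configs imply a fixed training cadence: with n_sample=1024, batch_size=256, and epoch_per_collect=4, each collection yields about 16 minibatch gradient steps, assuming the learner sweeps the collected batch once per epoch in minibatches of batch_size (the usual on-policy PPO loop). A quick sanity check:

# Rough sanity check of the update cadence implied by the new config values.
# Assumes the learner iterates over the collected samples epoch_per_collect times,
# splitting them into minibatches of batch_size.
n_sample = 1024          # transitions gathered per collection
batch_size = 256         # minibatch size per gradient step
epoch_per_collect = 4    # passes over the collected data

updates_per_collect = epoch_per_collect * n_sample // batch_size
print(updates_per_collect)  # -> 16 gradient steps per collection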
dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py
@@ -4,7 +4,7 @@
spaceinvaders_ppo_config = dict(
exp_name='spaceinvaders_onppo_seed0',
env=dict(
collector_env_num=16,
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=int(1e10),
@@ -21,18 +21,20 @@
obs_shape=[4, 84, 84],
action_shape=6,
action_space='discrete',
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
encoder_hidden_size_list=[32, 64, 64, 512],
actor_head_layer_num=0,
critic_head_layer_num=0,
actor_head_hidden_size=512,
critic_head_hidden_size=512,
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
epoch_per_collect=4,
batch_size=256,
learning_rate=2.5e-4,
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
entropy_weight=0.01,
clip_ratio=0.1,
adv_norm=True,
value_norm=True,
# for onppo, when we recompute adv, we need the key done in data to split traj, so we must
@@ -44,7 +46,7 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=1024,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
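Each file also gains lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0), which suggests annealing the learning rate toward zero over 5200 training epochs. How DI-engine consumes this key is not shown in the diff; the snippet below is only an illustrative sketch of one plausible interpretation (a linear LambdaLR decay from 1.0 down to min_lr_lambda), and everything outside the two config values is assumed.

# Illustrative only: one plausible reading of lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0).
# The actual DI-engine scheduler wiring is not part of this diff.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

epoch_num = 5200       # from the config
min_lr_lambda = 0.0    # from the config: final LR multiplier

model = torch.nn.Linear(4, 2)                    # stand-in network
optimizer = Adam(model.parameters(), lr=2.5e-4)  # learning_rate from the config

# Linearly anneal the LR multiplier from 1.0 down to min_lr_lambda over epoch_num epochs.
scheduler = LambdaLR(optimizer, lr_lambda=lambda e: max(min_lr_lambda, 1.0 - e / epoch_num))

for epoch in range(3):   # a few steps just to show usage
    optimizer.step()
    scheduler.step()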
