feature(pu): add ddp config of dqn and onppo (#842)
* feature(pu): add pong and cartpole ddp config of dqn and onppo

* fix(pu):fix atari_ppo_ddp.py

* polish(pu): polish atari_dqn_ddp.py and atari_ppo_ddp.py

* polish(pu): polish atari ddp configs
puyuan1996 authored Dec 19, 2024
1 parent 580ea65 commit 9a6e46f
Showing 11 changed files with 304 additions and 14 deletions.
4 changes: 2 additions & 2 deletions ding/entry/serial_entry_onpolicy.py
@@ -12,7 +12,7 @@
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
-from ding.utils import set_pkg_seed
+from ding.utils import set_pkg_seed, get_rank


def serial_pipeline_onpolicy(
@@ -68,7 +68,7 @@ def serial_pipeline_onpolicy(
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Create worker components: learner, collector, evaluator, replay buffer, commander.
-    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
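
Note: the rank-0 gate on the SummaryWriter above is the usual DDP logging pattern — only one process writes TensorBoard event files, and every other rank carries None. A minimal standalone sketch of the same pattern, assuming the process group was initialized by the launcher (make_tb_logger is an illustrative name, not a DI-engine API):

import os

import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter


def make_tb_logger(exp_name: str):
    # Only rank 0 owns a writer; other ranks get None and must guard
    # their logging calls with a rank check.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        return SummaryWriter(os.path.join('./{}/log/'.format(exp_name), 'serial'))
    return None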
6 changes: 5 additions & 1 deletion ding/worker/collector/interaction_serial_evaluator.py
@@ -204,6 +204,8 @@ def eval(
'''
        # evaluator only works on rank 0
stop_flag = False
+        episode_info = None  # Initialize to ensure it's defined on all ranks

if get_rank() == 0:
if n_episode is None:
n_episode = self._default_n_episode
@@ -317,5 +319,7 @@ def eval(
broadcast_object_list(objects, src=0)
stop_flag, episode_info = objects

-        episode_info = to_item(episode_info)
+        # Ensure episode_info is converted to the correct format
+        episode_info = to_item(episode_info) if episode_info is not None else {}

return stop_flag, episode_info
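
For context, the surrounding code lets rank 0 run the evaluation loop and then shares the outcome, which is why episode_info must exist on every rank before the broadcast. A minimal sketch of that pattern with plain torch.distributed (DI-engine's broadcast_object_list is assumed here to mirror the torch primitive of the same name):

import torch.distributed as dist


def share_eval_result(stop_flag, episode_info):
    # Rank 0 supplies the real values; the placeholders on other ranks
    # are overwritten in place by the broadcast.
    objects = [stop_flag, episode_info]
    dist.broadcast_object_list(objects, src=0)
    stop_flag, episode_info = objects
    return stop_flag, episode_info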
2 changes: 1 addition & 1 deletion ding/worker/collector/sample_serial_collector.py
@@ -7,7 +7,7 @@

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
-    broadcast_object_list, allreduce_data
+    allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions

67 changes: 67 additions & 0 deletions dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
@@ -0,0 +1,67 @@
from easydict import EasyDict

pong_dqn_config = dict(
exp_name='data_pong/pong_dqn_ddp_seed0',
env=dict(
collector_env_num=4,
evaluator_env_num=4,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
        # 'ALE/Pong-v5' is also available, but it requires special settings after gym.make().
frame_stack=4,
),
policy=dict(
multi_gpu=True,
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=96, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_dqn_config = EasyDict(pong_dqn_config)
main_config = pong_dqn_config
pong_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
pong_dqn_create_config = EasyDict(pong_dqn_create_config)
create_config = pong_dqn_create_config

if __name__ == '__main__':
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline
with DDPContext():
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e6))
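
DDPContext comes from ding.utils; a rough sketch of what such a context manager plausibly wraps is shown below (an illustration of the standard setup/teardown, not DI-engine's actual implementation):

import os
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def ddp_context(backend: str = 'nccl'):
    # torch.distributed.launch / torchrun export RANK, WORLD_SIZE and
    # LOCAL_RANK, which init_process_group reads via the env:// method.
    dist.init_process_group(backend=backend)
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
    try:
        yield
    finally:
        dist.destroy_process_group()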
dizoo/atari/config/serial/pong/{pong_onppo_config.py → pong_ppo_config.py}
@@ -1,6 +1,6 @@
from easydict import EasyDict

-pong_onppo_config = dict(
+pong_ppo_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
@@ -49,19 +49,19 @@
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
-main_config = EasyDict(pong_onppo_config)
+main_config = EasyDict(pong_ppo_config)

-pong_onppo_create_config = dict(
+pong_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
-create_config = EasyDict(pong_onppo_create_config)
+create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
-    # or you can enter `ding -m serial_onpolicy -c pong_onppo_config.py -s 0`
+    # or you can enter `ding -m serial_onpolicy -c pong_ppo_config.py -s 0`
from ding.entry import serial_pipeline_onpolicy
serial_pipeline_onpolicy((main_config, create_config), seed=0)
76 changes: 76 additions & 0 deletions dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict

pong_ppo_config = dict(
exp_name='data_pong/pong_ppo_ddp_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
        # 'ALE/Pong-v5' is also available, but it requires special settings after gym.make().
frame_stack=4,
),
policy=dict(
multi_gpu=True,
cuda=True,
recompute_adv=True,
action_space='discrete',
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
action_space='discrete',
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
            # For PPO, when advantages are recomputed, the `done` key in the data is needed to
            # split trajectories, so ignore_done=False must be used here.
            # If a `traj_flag` key is added to the data as a backup for `done`, ignore_done=True
            # becomes an option (e.g. for fixed-length tasks such as HalfCheetah, length 1000).
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
),
)
main_config = EasyDict(pong_ppo_config)

pong_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline_onpolicy
with DDPContext():
serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6))
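
With multi_gpu=True, each rank is expected to average its gradients with the others before every optimizer step, so the two GPUs stay in lockstep. A minimal sketch of that synchronization, assuming manual allreduce rather than DistributedDataParallel (sync_gradients is an illustrative name):

import torch
import torch.distributed as dist


def sync_gradients(model: torch.nn.Module) -> None:
    # Sum every parameter gradient across ranks, then divide by the world
    # size so each process applies the same averaged update.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)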
6 changes: 6 additions & 0 deletions dizoo/atari/example/atari_dqn_ddp.py
@@ -56,4 +56,10 @@ def main():


if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_dqn_ddp.py
"""
main()
2 changes: 1 addition & 1 deletion dizoo/atari/example/atari_ppo.py
@@ -11,7 +11,7 @@
gae_estimator, termination_checker
from ding.utils import set_pkg_seed
from dizoo.atari.envs.atari_env import AtariEnv
-from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
+from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
15 changes: 11 additions & 4 deletions dizoo/atari/example/atari_ppo_ddp.py
@@ -9,14 +9,14 @@
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, ddp_termination_checker, online_logger
-from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size
+from ding.utils import set_pkg_seed, DDPContext, get_rank, get_world_size
from dizoo.atari.envs.atari_env import AtariEnv
-from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
+from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
-    with DistContext():
+    with DDPContext():
rank, world_size = get_rank(), get_world_size()
main_config.example = 'pong_ppo_seed0_ddp_avgsplit'
main_config.policy.multi_gpu = True
@@ -45,12 +45,19 @@ def main():
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode))
-        task.use(multistep_trainer(cfg, policy.learn_mode))
+        task.use(multistep_trainer(policy.learn_mode))
if rank == 0:
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(online_logger(record_train_iter=True))
task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank))
task.run()


if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_ppo_ddp.py
"""
main()
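
ddp_termination_checker takes the rank explicitly because all processes must agree on when to stop; a plausible sketch of the aggregation such a checker performs is below (an assumption based on the middleware's name, not its actual source):

import torch
import torch.distributed as dist


def global_env_step_reached(local_env_step: int, max_env_step: int) -> bool:
    # Sum the env steps collected on every rank so that all processes
    # compare the same global counter and terminate together.
    counter = torch.tensor([local_env_step], dtype=torch.long)
    if dist.get_backend() == 'nccl':
        counter = counter.cuda()  # NCCL only reduces GPU tensors
    dist.all_reduce(counter, op=dist.ReduceOp.SUM)
    return int(counter.item()) >= max_env_step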
66 changes: 66 additions & 0 deletions dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
@@ -0,0 +1,66 @@
from easydict import EasyDict

cartpole_dqn_config = dict(
exp_name='cartpole_dqn_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
replay_path='cartpole_dqn_seed0/video',
),
policy=dict(
multi_gpu=True,
cuda=True,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
# dropout=0.1,
),
nstep=1,
discount_factor=0.97,
learn=dict(
update_per_collect=5,
batch_size=64,
learning_rate=0.001,
),
collect=dict(n_sample=8),
eval=dict(evaluator=dict(eval_freq=40, )),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_dqn_config = EasyDict(cartpole_dqn_config)
main_config = cartpole_dqn_config
cartpole_dqn_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
create_config = cartpole_dqn_create_config

if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline
with DDPContext():
serial_pipeline((main_config, create_config), seed=0)
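
A note on the launch commands used throughout this commit: --master_port=29501 simply avoids a rendezvous port clash when another DDP job already occupies the default port 29500, and on recent PyTorch releases `torchrun --nproc_per_node=2 <script>` is the maintained replacement for the deprecated `python -m torch.distributed.launch`.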
