dev(lwq): add continuous examples: ddpg, td3 and d4pg #384

Merged · 2 commits · Jun 28, 2022
50 changes: 50 additions & 0 deletions ding/example/d4pg.py
@@ -0,0 +1,50 @@
import gym
from ditk import logging
from ding.model.template.qac_dist import QACDIST
from ding.policy import D4PGPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    CkptSaver, nstep_reward_enhancer
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.config.pendulum_d4pg_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Vectorized Pendulum environments for data collection and evaluation.
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        # Distributional Q actor-critic model plus a replay buffer with prioritized sampling.
        model = QACDIST(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        buffer_.use(PriorityExperienceReplay(buffer_, IS_weight=True))
        policy = D4PGPolicy(cfg.policy, model)

        # Middleware pipeline: evaluate, collect, build n-step rewards, store, learn, checkpoint.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(nstep_reward_enhancer(cfg))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(cfg, policy, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()
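
The D4PG pipeline differs from the DDPG and TD3 examples below mainly through the prioritized-replay middleware and nstep_reward_enhancer. As a rough illustration only (not code from this PR), the quantity the enhancer prepares is an n-step discounted reward sum; gamma and n here mirror the discount_factor=0.99 and nstep=3 values in the pendulum config touched at the end of this diff:

# Illustrative sketch, not DI-engine code: the n-step discounted reward sum
# that nstep_reward_enhancer effectively prepares for the distributional critic
# (the bootstrap value term is handled inside the learner).
def nstep_return(rewards, gamma=0.99, n=3):
    # r_0 + gamma * r_1 + ... + gamma^(n-1) * r_(n-1)
    return sum((gamma ** i) * r for i, r in enumerate(rewards[:n]))


print(nstep_return([1.0, 1.0, 1.0]))  # 1 + 0.99 + 0.9801 = 2.9701
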
47 changes: 47 additions & 0 deletions ding/example/ddpg.py
@@ -0,0 +1,47 @@
import gym
from ditk import logging
from ding.model.template.qac import QAC
from ding.policy import DDPGPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    CkptSaver, termination_checker
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.config.pendulum_ddpg_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Vectorized Pendulum environments for data collection and evaluation.
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = QAC(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = DDPGPolicy(cfg.policy, model)

        # Middleware pipeline: evaluate, collect, store, learn, checkpoint, and stop after 10k iterations.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(cfg, policy, train_freq=100))
        task.use(termination_checker(max_train_iter=10000))
        task.run()


if __name__ == "__main__":
    main()
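
Unlike the TD3 and D4PG scripts, this example adds termination_checker(max_train_iter=10000) so the loop exits on its own instead of running until interrupted. A rough sketch of what such a middleware does; ctx.train_iter and task.finish are assumed names based on the OnlineRLContext/task API used above, not the actual DI-engine implementation:

# Hypothetical sketch of a termination middleware: end the task once the
# learner has performed max_train_iter updates.
from ding.framework import task  # same task object used in the example above


def simple_termination_checker(max_train_iter):
    def _check(ctx):
        if ctx.train_iter >= max_train_iter:
            task.finish = True  # assumed flag that makes task.run() stop
    return _check
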
45 changes: 45 additions & 0 deletions ding/example/td3.py
@@ -0,0 +1,45 @@
import gym
from ditk import logging
from ding.model.template.qac import QAC
from ding.policy import TD3Policy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, CkptSaver
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.config.pendulum_td3_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Vectorized Pendulum environments for data collection and evaluation.
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("Pendulum-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = QAC(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = TD3Policy(cfg.policy, model)

        # Middleware pipeline: evaluate, collect, store, learn, checkpoint.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(
            StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
        )
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(cfg, policy, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()
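
The TD3 script is wired identically to the DDPG one; the algorithmic differences (twin critics, delayed actor updates, target policy smoothing) live in TD3Policy and the pendulum_td3_config rather than in this pipeline. As a standalone illustration of the smoothing trick only, with noise parameters assumed and not taken from this PR's config:

# Illustrative only: TD3's target policy smoothing, i.e. clipped Gaussian noise
# added to the target actor's action before the critic target is computed.
import torch


def smoothed_target_action(target_actor, next_obs, sigma=0.2, noise_clip=0.5):
    action = target_actor(next_obs)
    noise = (torch.randn_like(action) * sigma).clamp(-noise_clip, noise_clip)
    return (action + noise).clamp(-1.0, 1.0)
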
@@ -14,6 +14,7 @@
        cuda=False,
        priority=True,
        nstep=3,
        discount_factor=0.99,
        random_collect_size=800,
        model=dict(
            obs_shape=3,
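
For context, the hunk above (presumably dizoo/classic_control/pendulum/config/pendulum_d4pg_config.py, given the d4pg example's import and the nstep/priority fields) adds discount_factor inside the policy dict. A sketch of the assumed surrounding structure; only the fields shown in the diff are certain:

# Assumed nesting for illustration; fields outside the diff are placeholders.
main_config = dict(
    policy=dict(
        cuda=False,
        priority=True,
        nstep=3,
        discount_factor=0.99,  # added by this PR
        random_collect_size=800,
        model=dict(
            obs_shape=3,
            # ...
        ),
        # ...
    ),
    # ...
)
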