feature(whl): add more DingEnvWrapper example #525

Merged · 7 commits · Nov 15, 2022
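This PR adds two hand-written entry examples (dmc2gym PPO and MiniGrid PPO) that build their environments through DingEnvWrapper: a raw gym/gymnasium environment is passed in together with a list of wrapper constructors under the 'env_wrapper' key. A minimal sketch of that pattern, using CartPole-v0 only as a stand-in environment (it is not part of this PR), looks like this:

import gym

from ding.envs import DingEnvWrapper, FinalEvalRewardEnv


def wrapped_example_env():
    # Sketch of the DingEnvWrapper pattern demonstrated by this PR:
    # each 'env_wrapper' entry is a callable that wraps the previous env.
    return DingEnvWrapper(
        gym.make('CartPole-v0'),  # stand-in raw environment for illustration
        cfg={
            'env_wrapper': [
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )
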
64 changes: 64 additions & 0 deletions dizoo/dmc2gym/config/dmc2gym_ppo_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict


cartpole_balance_ppo_config = dict(
    exp_name='dmc2gym_cartpole_balance_ppo',
    env=dict(
        env_id='dmc2gym_cartpole_balance',
        domain_name='cartpole',
        task_name='balance',
        from_pixels=False,
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=1000,
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=5,
            action_shape=1,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
    ),
)
cartpole_balance_ppo_config = EasyDict(cartpole_balance_ppo_config)
main_config = cartpole_balance_ppo_config

cartpole_balance_create_config = dict(
    env=dict(
        type='dmc2gym',
        import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
    replay_buffer=dict(type='naive', ),
)
cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
create_config = cartpole_balance_create_config

# To use this config, go to dizoo/dmc2gym/entry and run dmc2gym_onppo_main.py
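A hedged sketch of launching training with this config from Python (assuming the entry directory is importable as a package; otherwise run the script directly as the comment above says):

# Sketch only: call the PPO entry added in this PR with this config.
from dizoo.dmc2gym.config.dmc2gym_ppo_config import main_config
from dizoo.dmc2gym.entry.dmc2gym_onppo_main import main  # assumes dizoo.dmc2gym.entry is importable

if __name__ == '__main__':
    main(main_config, seed=0)
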
122 changes: 122 additions & 0 deletions dizoo/dmc2gym/entry/dmc2gym_onppo_main.py
@@ -0,0 +1,122 @@
import os
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import dmc2gym

from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed
from dizoo.dmc2gym.config.dmc2gym_ppo_config import cartpole_balance_ppo_config
from dizoo.dmc2gym.envs.dmc2gym_env import *


class Dmc2GymWrapper(gym.Wrapper):

    def __init__(self, env, cfg):
        super().__init__(env)
        cfg = EasyDict(cfg)
        self._cfg = cfg

        env_info = dmc2gym_env_info[cfg.domain_name][cfg.task_name]

        self._observation_space = env_info["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
            channels_first=self._cfg["channels_first"]
        )
        self._action_space = env_info["action_space"]
        self._reward_space = env_info["reward_space"](self._cfg["frame_skip"])

    def _process_obs(self, obs):
        if self._cfg["from_pixels"]:
            obs = to_ndarray(obs).astype(np.uint8)
        else:
            obs = to_ndarray(obs).astype(np.float32)
        return obs

    def step(self, action):
        action = np.array([action]).astype('float32')
        obs, reward, done, info = self.env.step(action)
        return self._process_obs(obs), reward, done, info

    def reset(self):
        obs = self.env.reset()
        return self._process_obs(obs)


def wrapped_dmc2gym_env(cfg):
    default_cfg = {
        "frame_skip": 3,
        "from_pixels": True,
        "visualize_reward": False,
        "height": 100,
        "width": 100,
        "channels_first": True,
    }
    default_cfg.update(cfg)

    return DingEnvWrapper(
        dmc2gym.make(
            domain_name=default_cfg["domain_name"],
            task_name=default_cfg["task_name"],
            seed=1,
            visualize_reward=default_cfg["visualize_reward"],
            from_pixels=default_cfg["from_pixels"],
            height=default_cfg["height"],
            width=default_cfg["width"],
            frame_skip=default_cfg["frame_skip"]
        ),
        cfg={
            'env_wrapper': [
                lambda env: Dmc2GymWrapper(env, default_cfg),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(collector_env_num)],
        cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager
    )

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break


if __name__ == '__main__':
    main(cartpole_balance_ppo_config)
94 changes: 94 additions & 0 deletions dizoo/minigrid/entry/minigrid_onppo_main.py
@@ -0,0 +1,94 @@
import gym
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed

from dizoo.minigrid.config.minigrid_onppo_config import minigrid_ppo_config
from minigrid.wrappers import FlatObsWrapper
import numpy as np
from tensorboardX import SummaryWriter
import os
import gymnasium


class MinigridWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)
        self._observation_space = gym.spaces.Box(
            low=float("-inf"),
            high=float("inf"),
            shape=(8, ),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(9)
        self._action_space.seed(0)  # default seed
        self.reward_range = (float('-inf'), float('inf'))
        self.max_steps = minigrid_ppo_config.env.max_step

    def step(self, action):
        obs, reward, done, _, info = self.env.step(action)
        self.cur_step += 1
        if self.cur_step > self.max_steps:
            done = True
        return obs, reward, done, info

    def reset(self):
        self.cur_step = 0
        return self.env.reset()[0]


def wrapped_minigrid_env():
    return DingEnvWrapper(
        gymnasium.make(minigrid_ppo_config.env.env_id),
        cfg={
            'env_wrapper': [
                lambda env: FlatObsWrapper(env),
                lambda env: MinigridWrapper(env),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
    evaluator_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break


if __name__ == '__main__':
    main(minigrid_ppo_config)
File renamed without changes.
@@ -1,6 +1,6 @@
from easydict import EasyDict

-bigfish_plr_default_config = dict(
+bigfish_plr_config = dict(
exp_name='bigfish_plr_seed1',
env=dict(
is_train=True,
@@ -42,8 +42,8 @@
temperature=0.1,
),
)
-bigfish_plr_default_config = EasyDict(bigfish_plr_default_config)
-main_config = bigfish_plr_default_config
+bigfish_plr_config = EasyDict(bigfish_plr_config)
+main_config = bigfish_plr_config

bigfish_plr_create_config = dict(
env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-bigfish_ppg_default_config = dict(
+bigfish_ppg_config = dict(
exp_name='bigfish_ppg_seed0',
env=dict(
is_train=True,
@@ -37,8 +37,8 @@
other=dict(),
),
)
-bigfish_ppg_default_config = EasyDict(bigfish_ppg_default_config)
-main_config = bigfish_ppg_default_config
+bigfish_ppg_config = EasyDict(bigfish_ppg_config)
+main_config = bigfish_ppg_config

bigfish_ppg_create_config = dict(
env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_dqn_default_config = dict(
+coinrun_dqn_config = dict(
env=dict(
env_id='coinrun',
collector_env_num=4,
@@ -36,8 +36,8 @@
),
),
)
-coinrun_dqn_default_config = EasyDict(coinrun_dqn_default_config)
-main_config = coinrun_dqn_default_config
+coinrun_dqn_config = EasyDict(coinrun_dqn_config)
+main_config = coinrun_dqn_config

coinrun_dqn_create_config = dict(
env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_ppg_default_config = dict(
+coinrun_ppg_config = dict(
exp_name='coinrun_ppg_seed0',
env=dict(
is_train=True,
@@ -37,8 +37,8 @@
other=dict(),
),
)
-coinrun_ppg_default_config = EasyDict(coinrun_ppg_default_config)
-main_config = coinrun_ppg_default_config
+coinrun_ppg_config = EasyDict(coinrun_ppg_config)
+main_config = coinrun_ppg_config

coinrun_ppg_create_config = dict(
env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_ppo_default_config = dict(
+coinrun_ppo_config = dict(
env=dict(
is_train=True,
env_id='coinrun',
@@ -39,8 +39,8 @@
),
),
)
-coinrun_ppo_default_config = EasyDict(coinrun_ppo_default_config)
-main_config = coinrun_ppo_default_config
+coinrun_ppo_config = EasyDict(coinrun_ppo_config)
+main_config = coinrun_ppo_config

coinrun_ppo_create_config = dict(
env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-maze_dqn_default_config = dict(
+maze_dqn_config = dict(
env=dict(
collector_env_num=4,
env_id='maze',
@@ -37,8 +37,8 @@
),
),
)
-maze_dqn_default_config = EasyDict(maze_dqn_default_config)
-main_config = maze_dqn_default_config
+maze_dqn_config = EasyDict(maze_dqn_config)
+main_config = maze_dqn_config

maze_dqn_create_config = dict(
env=dict(