feature(whl): add more DingEnvWrapper example (#525)
* wrap env for rocket and minigrid

* wrap env for rocket and minigrid

* add procgen & dmc2gym

* polish pr according to Pu

* polish code
kxzxvbk authored Nov 15, 2022
1 parent 721e671 commit 049acb3
Showing 15 changed files with 601 additions and 24 deletions.
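
The new entry files in this commit all follow the same recipe: build the raw environment, then pass it to DingEnvWrapper together with a list of wrapper factories under the 'env_wrapper' key of its cfg. Below is a minimal sketch of that pattern, using only API calls that appear in the files in this diff; the env id and the one-element wrapper chain are placeholders, not part of the commit:

import gym
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv


def wrapped_example_env():
    # Each example below swaps in its own env factory (dmc2gym.make, gymnasium.make, ...)
    # and its own wrapper chain; FinalEvalRewardEnv is the common last wrapper.
    return DingEnvWrapper(
        gym.make('CartPole-v0'),  # placeholder env id
        cfg={
            'env_wrapper': [
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )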
64 changes: 64 additions & 0 deletions dizoo/dmc2gym/config/dmc2gym_ppo_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict


cartpole_balance_ppo_config = dict(
    exp_name='dmc2gym_cartpole_balance_ppo',
    env=dict(
        env_id='dmc2gym_cartpole_balance',
        domain_name='cartpole',
        task_name='balance',
        from_pixels=False,
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=1000,
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=5,
            action_shape=1,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
    )
)
cartpole_balance_ppo_config = EasyDict(cartpole_balance_ppo_config)
main_config = cartpole_balance_ppo_config

cartpole_balance_create_config = dict(
    env=dict(
        type='dmc2gym',
        import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
    replay_buffer=dict(type='naive', ),
)
cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
create_config = cartpole_balance_create_config

# To use this config, you can enter dizoo/dmc2gym/entry to call dmc2gym_onppo_main.py
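
As the closing comment notes, this config is consumed by the new dmc2gym_onppo_main.py entry shown next. A hedged usage sketch, with import paths as added in this commit:

from dizoo.dmc2gym.config.dmc2gym_ppo_config import main_config
from dizoo.dmc2gym.entry.dmc2gym_onppo_main import main

if __name__ == '__main__':
    main(main_config)  # main_config is the EasyDict built above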
122 changes: 122 additions & 0 deletions dizoo/dmc2gym/entry/dmc2gym_onppo_main.py
@@ -0,0 +1,122 @@
import os
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import gym  # used directly by Dmc2GymWrapper below
import numpy as np
import dmc2gym

from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed
from dizoo.dmc2gym.config.dmc2gym_ppo_config import cartpole_balance_ppo_config
from dizoo.dmc2gym.envs.dmc2gym_env import *


class Dmc2GymWrapper(gym.Wrapper):
    """Adapt a raw dmc2gym env to the observation/action/reward spaces and obs dtypes DI-engine expects."""

    def __init__(self, env, cfg):
        super().__init__(env)
        cfg = EasyDict(cfg)
        self._cfg = cfg

        env_info = dmc2gym_env_info[cfg.domain_name][cfg.task_name]

        self._observation_space = env_info["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
            channels_first=self._cfg["channels_first"]
        )
        self._action_space = env_info["action_space"]
        self._reward_space = env_info["reward_space"](self._cfg["frame_skip"])

    def _process_obs(self, obs):
        if self._cfg["from_pixels"]:
            obs = to_ndarray(obs).astype(np.uint8)
        else:
            obs = to_ndarray(obs).astype(np.float32)
        return obs

    def step(self, action):
        action = np.array([action]).astype('float32')
        obs, reward, done, info = self.env.step(action)
        return self._process_obs(obs), reward, done, info

    def reset(self):
        obs = self.env.reset()
        return self._process_obs(obs)


def wrapped_dmc2gym_env(cfg):
    default_cfg = {
        "frame_skip": 3,
        "from_pixels": True,
        "visualize_reward": False,
        "height": 100,
        "width": 100,
        "channels_first": True,
    }
    default_cfg.update(cfg)

    return DingEnvWrapper(
        dmc2gym.make(
            domain_name=default_cfg["domain_name"],
            task_name=default_cfg["task_name"],
            seed=1,
            visualize_reward=default_cfg["visualize_reward"],
            from_pixels=default_cfg["from_pixels"],
            height=default_cfg["height"],
            width=default_cfg["width"],
            frame_skip=default_cfg["frame_skip"]
        ),
        cfg={
            'env_wrapper': [
                lambda env: Dmc2GymWrapper(env, default_cfg),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(collector_env_num)],
        cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager
    )

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break


if __name__ == '__main__':
    main(cartpole_balance_ppo_config)
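
Since main() also exposes seed, max_env_step and max_train_iter (see its signature above), a capped smoke run might look like the following sketch; the budget values here are illustrative, not from the commit:

main(cartpole_balance_ppo_config, seed=0, max_env_step=int(1e5), max_train_iter=int(1e3))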
94 changes: 94 additions & 0 deletions dizoo/minigrid/entry/minigrid_onppo_main.py
@@ -0,0 +1,94 @@
import os

import gym
import gymnasium
import numpy as np
from tensorboardX import SummaryWriter
from minigrid.wrappers import FlatObsWrapper

from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed
from dizoo.minigrid.config.minigrid_onppo_config import minigrid_ppo_config


class MinigridWrapper(gym.Wrapper):
    """Bridge the gymnasium-based MiniGrid env to the gym-style (obs, reward, done, info) interface."""

    def __init__(self, env):
        super().__init__(env)
        self._observation_space = gym.spaces.Box(
            low=float("-inf"),
            high=float("inf"),
            shape=(8, ),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(9)
        self._action_space.seed(0)  # default seed
        self.reward_range = (float('-inf'), float('inf'))
        self.max_steps = minigrid_ppo_config.env.max_step

    def step(self, action):
        # gymnasium returns (obs, reward, terminated, truncated, info); drop the truncated flag
        obs, reward, done, _, info = self.env.step(action)
        self.cur_step += 1
        if self.cur_step > self.max_steps:
            done = True
        return obs, reward, done, info

    def reset(self):
        self.cur_step = 0
        # gymnasium reset returns (obs, info); keep only the observation
        return self.env.reset()[0]


def wrapped_minigrid_env():
    return DingEnvWrapper(
        gymnasium.make(minigrid_ppo_config.env.env_id),
        cfg={
            'env_wrapper': [
                lambda env: FlatObsWrapper(env),
                lambda env: MinigridWrapper(env),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
    evaluator_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break


if __name__ == '__main__':
    main(minigrid_ppo_config)
File renamed without changes.
@@ -1,6 +1,6 @@
from easydict import EasyDict

-bigfish_plr_default_config = dict(
+bigfish_plr_config = dict(
    exp_name='bigfish_plr_seed1',
    env=dict(
        is_train=True,
@@ -42,8 +42,8 @@
        temperature=0.1,
    ),
)
-bigfish_plr_default_config = EasyDict(bigfish_plr_default_config)
-main_config = bigfish_plr_default_config
+bigfish_plr_config = EasyDict(bigfish_plr_config)
+main_config = bigfish_plr_config

bigfish_plr_create_config = dict(
    env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-bigfish_ppg_default_config = dict(
+bigfish_ppg_config = dict(
    exp_name='bigfish_ppg_seed0',
    env=dict(
        is_train=True,
@@ -37,8 +37,8 @@
        other=dict(),
    ),
)
-bigfish_ppg_default_config = EasyDict(bigfish_ppg_default_config)
-main_config = bigfish_ppg_default_config
+bigfish_ppg_config = EasyDict(bigfish_ppg_config)
+main_config = bigfish_ppg_config

bigfish_ppg_create_config = dict(
    env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_dqn_default_config = dict(
+coinrun_dqn_config = dict(
    env=dict(
        env_id='coinrun',
        collector_env_num=4,
@@ -36,8 +36,8 @@
        ),
    ),
)
-coinrun_dqn_default_config = EasyDict(coinrun_dqn_default_config)
-main_config = coinrun_dqn_default_config
+coinrun_dqn_config = EasyDict(coinrun_dqn_config)
+main_config = coinrun_dqn_config

coinrun_dqn_create_config = dict(
    env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_ppg_default_config = dict(
+coinrun_ppg_config = dict(
    exp_name='coinrun_ppg_seed0',
    env=dict(
        is_train=True,
@@ -37,8 +37,8 @@
        other=dict(),
    ),
)
-coinrun_ppg_default_config = EasyDict(coinrun_ppg_default_config)
-main_config = coinrun_ppg_default_config
+coinrun_ppg_config = EasyDict(coinrun_ppg_config)
+main_config = coinrun_ppg_config

coinrun_ppg_create_config = dict(
    env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-coinrun_ppo_default_config = dict(
+coinrun_ppo_config = dict(
    env=dict(
        is_train=True,
        env_id='coinrun',
@@ -39,8 +39,8 @@
        ),
    ),
)
-coinrun_ppo_default_config = EasyDict(coinrun_ppo_default_config)
-main_config = coinrun_ppo_default_config
+coinrun_ppo_config = EasyDict(coinrun_ppo_config)
+main_config = coinrun_ppo_config

coinrun_ppo_create_config = dict(
    env=dict(
@@ -1,6 +1,6 @@
from easydict import EasyDict

-maze_dqn_default_config = dict(
+maze_dqn_config = dict(
    env=dict(
        collector_env_num=4,
        env_id='maze',
@@ -37,8 +37,8 @@
        ),
    ),
)
-maze_dqn_default_config = EasyDict(maze_dqn_default_config)
-main_config = maze_dqn_default_config
+maze_dqn_config = EasyDict(maze_dqn_config)
+main_config = maze_dqn_config

maze_dqn_create_config = dict(
    env=dict(
