feature(whl): add more DingEnvWrapper example #525

Merged · 7 commits · Nov 15, 2022 (diff shown from 3 commits)
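
All three new entry files follow the same pattern: build the raw environment, then pass it to DingEnvWrapper together with a list of env_wrapper lambdas that are applied in order. A minimal sketch of that pattern (the CartPole-v1 env, the wrapped_cartpole_env name and the random_action call are illustrative assumptions, not part of this PR):

import gym
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv


def wrapped_cartpole_env():
    # DingEnvWrapper applies each env_wrapper lambda in order to the raw gym env.
    return DingEnvWrapper(
        gym.make('CartPole-v1'),
        cfg={'env_wrapper': [lambda env: FinalEvalRewardEnv(env)]}
    )


env = wrapped_cartpole_env()
obs = env.reset()
timestep = env.step(env.random_action())  # returns a BaseEnvTimestep namedtuple
print(obs.shape, timestep.reward, timestep.done)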
62 changes: 62 additions & 0 deletions dizoo/dmc2gym/config/dmc2gym_ppo_config.py
@@ -0,0 +1,62 @@
from easydict import EasyDict


cartpole_balance_ppo_config = dict(
    exp_name='dmc2gym_cartpole_balance_ppo',
    env=dict(
        env_id='dmc2gym_cartpole_balance',
        domain_name='cartpole',
        task_name='balance',
        from_pixels=False,
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=1000,
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=5,
            action_shape=1,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
    )
)
cartpole_balance_ppo_config = EasyDict(cartpole_balance_ppo_config)
main_config = cartpole_balance_ppo_config

cartpole_balance_create_config = dict(
    env=dict(
        type='dmc2gym',
        import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
    replay_buffer=dict(type='naive', ),
)
cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
create_config = cartpole_balance_create_config
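
The main_config / create_config pair above is consumed by the hand-written ppo_main.py entry below via compile_config. As a rough alternative, the same pair can be fed to DI-engine's generic on-policy entry; a hedged sketch, assuming serial_pipeline_onpolicy accepts this config unchanged:

# Hedged sketch, not part of this PR: run the config above through the generic
# on-policy serial pipeline instead of the hand-written entry below.
from ding.entry import serial_pipeline_onpolicy
from dizoo.dmc2gym.config.dmc2gym_ppo_config import main_config, create_config

if __name__ == '__main__':
    serial_pipeline_onpolicy([main_config, create_config], seed=0)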
125 changes: 125 additions & 0 deletions dizoo/dmc2gym/entry/ppo_main.py
@@ -0,0 +1,125 @@
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed

from dizoo.dmc2gym.config.dmc2gym_ppo_config import cartpole_balance_ppo_config
from tensorboardX import SummaryWriter
import os
import gym
import numpy as np
import dmc2gym
from dizoo.dmc2gym.envs.dmc2gym_env import *
from easydict import EasyDict


class InfoWrapper(gym.Wrapper):
    """Attach the observation/action/reward spaces from dmc2gym_env_info and normalize obs dtypes."""

    def __init__(self, env, cfg):
        super().__init__(env)
        cfg = EasyDict(cfg)
        self._cfg = cfg
        self._observation_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
            channels_first=self._cfg["channels_first"]
        )
        self._action_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["action_space"]
        self._reward_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["reward_space"](self._cfg["frame_skip"])

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def _process_obs(self, obs):
        if self._cfg["from_pixels"]:
            obs = to_ndarray(obs).astype(np.uint8)
        else:
            obs = to_ndarray(obs).astype(np.float32)
        return obs

    def step(self, action):
        action = np.array([action]).astype('float32')
        obs, reward, done, info = self.env.step(action)
        return self._process_obs(obs), reward, done, info

    def reset(self):
        obs = self.env.reset()
        return self._process_obs(obs)


def wrapped_dmc2gym_env(cfg):
    default_cfg = {
        "frame_skip": 3,
        "from_pixels": True,
        "visualize_reward": False,
        "height": 100,
        "width": 100,
        "channels_first": True,
    }
    default_cfg.update(cfg)

    return DingEnvWrapper(
        dmc2gym.make(
            domain_name=default_cfg["domain_name"],
            task_name=default_cfg["task_name"],
            seed=1,
            visualize_reward=default_cfg["visualize_reward"],
            from_pixels=default_cfg["from_pixels"],
            height=default_cfg["height"],
            width=default_cfg["width"],
            frame_skip=default_cfg["frame_skip"]
        ),
        cfg={
            'env_wrapper': [
                lambda env: InfoWrapper(env, default_cfg),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[lambda: wrapped_dmc2gym_env(cartpole_balance_ppo_config.env) for _ in range(collector_env_num)],
        cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[lambda: wrapped_dmc2gym_env(cartpole_balance_ppo_config.env) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager
    )
    # Equivalent setup with the hand-written env class instead of DingEnvWrapper:
    # collector_env = BaseEnvManager(env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(collector_env_num)],
    #                                cfg=cfg.env.manager)
    # evaluator_env = BaseEnvManager(env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(evaluator_env_num)],
    #                                cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    for _ in range(max_iterations):
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)


if __name__ == '__main__':
    main(cartpole_balance_ppo_config)
97 changes: 97 additions & 0 deletions dizoo/minigrid/entry/minigrid_ppo_main.py
@@ -0,0 +1,97 @@
import gym
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed

from dizoo.minigrid.config.minigrid_onppo_config import minigrid_ppo_config
from minigrid.wrappers import FlatObsWrapper
import numpy as np
from tensorboardX import SummaryWriter
import os
import gymnasium


class InfoWrapper(gym.Wrapper):
    """Overwrite the env's observation/action spaces and enforce a max-step limit."""

    def __init__(self, env):
        super().__init__(env)
        self._observation_space = gym.spaces.Box(
            low=float("-inf"), high=float("inf"), shape=(8, ), dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(9)
        self._action_space.seed(0)  # default seed
        self.reward_range = (float('-inf'), float('inf'))
        self.max_steps = minigrid_ppo_config.env.max_step

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action):
        # gymnasium returns (obs, reward, terminated, truncated, info); fold it back into the 4-tuple gym API.
        obs, reward, done, _, info = self.env.step(action)
        self.cur_step += 1
        if self.cur_step > self.max_steps:
            done = True
        return obs, reward, done, info

    def reset(self):
        self.cur_step = 0
        return self.env.reset()[0]


def wrapped_minigrid_env():
    return DingEnvWrapper(
        gymnasium.make(minigrid_ppo_config.env.env_id),
        cfg={
            'env_wrapper': [
                lambda env: FlatObsWrapper(env),
                lambda env: InfoWrapper(env),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
    evaluator_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    for _ in range(max_iterations):
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)


if __name__ == '__main__':
    main(minigrid_ppo_config)
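
Before launching training, the wrapper stack can be sanity-checked on its own. A hypothetical smoke test (not part of this PR; it assumes minigrid and gymnasium are installed and that DingEnvWrapper exposes seed, reset, step and random_action as in other DI-engine envs):

from dizoo.minigrid.entry.minigrid_ppo_main import wrapped_minigrid_env

env = wrapped_minigrid_env()
env.seed(0)
obs = env.reset()
timestep = env.step(env.random_action())
print(obs.shape, timestep.reward, timestep.done)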
File renamed without changes.
112 changes: 112 additions & 0 deletions dizoo/procgen/entry/coinrun_ppo_main.py
@@ -0,0 +1,112 @@
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager
from ding.config import compile_config
from ding.utils import set_pkg_seed

from dizoo.procgen.envs import ProcgenEnv
from tensorboardX import SummaryWriter
import os
import gym
import numpy as np
from dizoo.procgen.config.coinrun_ppo_config import coinrun_ppo_default_config
from dizoo.dmc2gym.envs.dmc2gym_env import *
from easydict import EasyDict


class InfoWrapper(gym.Wrapper):
    """Overwrite the procgen env's spaces and convert observations to CHW float32 arrays."""

    def __init__(self, env, cfg):
        super().__init__(env)
        cfg = EasyDict(cfg)
        self._cfg = cfg
        self._observation_space = gym.spaces.Box(
            low=np.zeros(shape=(3, 64, 64)), high=np.ones(shape=(3, 64, 64)) * 255, shape=(3, 64, 64), dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(15)
        self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def _process_obs(self, obs):
        obs = to_ndarray(obs)
        obs = np.transpose(obs, (2, 0, 1))  # HWC -> CHW
        obs = obs.astype(np.float32)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._process_obs(obs), reward, bool(done), info

    def reset(self):
        obs = self.env.reset()
        return self._process_obs(obs)


def wrapped_procgen_env(cfg):
    default_cfg = dict(
        control_level=True,
        start_level=0,
        num_levels=0,
        env_id='coinrun',
    )
    default_cfg.update(cfg)
    default_cfg = EasyDict(default_cfg)

    return DingEnvWrapper(
        gym.make(
            'procgen:procgen-' + default_cfg.env_id + '-v0',
            start_level=default_cfg.start_level,
            num_levels=default_cfg.num_levels
        ) if default_cfg.control_level else
        gym.make('procgen:procgen-' + default_cfg.env_id + '-v0', start_level=0, num_levels=1),
        cfg={
            'env_wrapper': [
                lambda env: InfoWrapper(env, default_cfg),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[lambda: wrapped_procgen_env(coinrun_ppo_default_config.env) for _ in range(collector_env_num)],
        cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[lambda: wrapped_procgen_env(coinrun_ppo_default_config.env) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager
    )
    # Equivalent setup with the hand-written env class instead of DingEnvWrapper:
    # collector_env = BaseEnvManager(env_fn=[lambda: ProcgenEnv(cfg.env) for _ in range(collector_env_num)],
    #                                cfg=cfg.env.manager)
    # evaluator_env = BaseEnvManager(env_fn=[lambda: ProcgenEnv(cfg.env) for _ in range(evaluator_env_num)],
    #                                cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    for _ in range(max_iterations):
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)


if __name__ == '__main__':
    main(coinrun_ppo_default_config)