-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(whl): add more DingEnvWrapper example (#525)
* wrap env for rocket and minigrid * wrap env for rocket and minigrid * add procgen & dmc2gym * polish pr according to Pu * polish code
- Loading branch information
Showing
15 changed files
with
601 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from easydict import EasyDict | ||
|
||
|
||
# Main hyper-parameter config for PPO on the DeepMind Control Suite
# `cartpole-balance` task, accessed through the dmc2gym adapter.
cartpole_balance_ppo_config = dict(
    exp_name='dmc2gym_cartpole_balance_ppo',
    env=dict(
        env_id='dmc2gym_cartpole_balance',
        domain_name='cartpole',
        task_name='balance',
        from_pixels=False,  # low-dimensional state observations, not rendered frames
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=1000,  # eval-return stop threshold (dm_control episodes cap near 1000 -- TODO confirm)
        ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        # NOTE(review): dm_control tasks are continuous-control; 'discrete'
        # together with action_shape=1 looks unusual -- confirm intended.
        action_space='discrete',
        model=dict(
            obs_shape=5,
            action_shape=1,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
            ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,  # weight of the critic (value) loss term
            entropy_weight=0.01,  # exploration bonus weight
            clip_ratio=0.2,  # PPO clipping epsilon
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
            ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
            ),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
    )
)
cartpole_balance_ppo_config = EasyDict(cartpole_balance_ppo_config)
main_config = cartpole_balance_ppo_config
|
||
# Component-registry config: names the env / env-manager / policy / buffer
# implementations DI-engine should instantiate for this experiment.
cartpole_balance_create_config = dict(
    env=dict(
        type='dmc2gym',
        import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
    replay_buffer=dict(type='naive', ),
)
cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
create_config = cartpole_balance_create_config

# To use this config, you can enter dizoo/dmc2gym/entry to call dmc2gym_onppo_main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import os | ||
from easydict import EasyDict | ||
from functools import partial | ||
from tensorboardX import SummaryWriter | ||
import dmc2gym | ||
|
||
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator | ||
from ding.model import VAC | ||
from ding.policy import PPOPolicy | ||
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager | ||
from ding.config import compile_config | ||
from ding.utils import set_pkg_seed | ||
from dizoo.dmc2gym.config.dmc2gym_ppo_config import cartpole_balance_ppo_config | ||
from dizoo.dmc2gym.envs.dmc2gym_env import * | ||
|
||
|
||
class Dmc2GymWrapper(gym.Wrapper):
    """Attach DI-engine space metadata to a dmc2gym environment and
    normalize observation dtypes.

    The observation/action/reward spaces are looked up from the
    ``dmc2gym_env_info`` table keyed by domain and task name.
    """

    def __init__(self, env, cfg):
        super().__init__(env)
        self._cfg = EasyDict(cfg)

        env_info = dmc2gym_env_info[self._cfg.domain_name][self._cfg.task_name]

        # Space factories are parameterized by the rendering options so the
        # declared observation space matches what the env actually emits.
        self._observation_space = env_info["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
            channels_first=self._cfg["channels_first"]
        )
        self._action_space = env_info["action_space"]
        self._reward_space = env_info["reward_space"](self._cfg["frame_skip"])

    def _process_obs(self, obs):
        # Pixel observations stay uint8 images; state vectors become float32.
        target_dtype = np.uint8 if self._cfg["from_pixels"] else np.float32
        return to_ndarray(obs).astype(target_dtype)

    def step(self, action):
        # dm_control expects a float32 array-shaped action.
        wrapped_action = np.array([action]).astype('float32')
        obs, reward, done, info = self.env.step(wrapped_action)
        return self._process_obs(obs), reward, done, info

    def reset(self):
        return self._process_obs(self.env.reset())
|
||
|
||
def wrapped_dmc2gym_env(cfg):
    """Build a DI-engine-ready dmc2gym environment.

    Arguments:
        cfg: dict with at least ``domain_name`` and ``task_name``; may
            override any key of ``default_cfg`` below (including ``seed``).

    Returns:
        A ``DingEnvWrapper`` that applies ``Dmc2GymWrapper`` (space metadata
        and obs dtype handling) and ``FinalEvalRewardEnv`` (episode-return
        bookkeeping) on top of the raw dmc2gym env.
    """
    default_cfg = {
        "frame_skip": 3,
        "from_pixels": True,
        "visualize_reward": False,
        "height": 100,
        "width": 100,
        "channels_first": True,
        # Previously hard-coded to 1 inside dmc2gym.make; exposed as a cfg
        # key so callers can vary it. Default keeps the old behavior.
        "seed": 1,
    }
    default_cfg.update(cfg)

    return DingEnvWrapper(
        dmc2gym.make(
            domain_name=default_cfg["domain_name"],
            task_name=default_cfg["task_name"],
            seed=default_cfg["seed"],
            visualize_reward=default_cfg["visualize_reward"],
            from_pixels=default_cfg["from_pixels"],
            height=default_cfg["height"],
            width=default_cfg["width"],
            frame_skip=default_cfg["frame_skip"]
        ),
        cfg={
            'env_wrapper': [
                lambda env: Dmc2GymWrapper(env, default_cfg),
                lambda env: FinalEvalRewardEnv(env),
            ]
        }
    )
|
||
|
||
def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    """Serial PPO training entry point for dmc2gym cartpole-balance.

    Arguments:
        cfg: experiment config (see ``dmc2gym_ppo_config``).
        seed: RNG seed applied to envs, packages, and policy.
        max_env_step: hard budget on collected environment steps.
        max_train_iter: hard budget on learner iterations.
    """
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    # Bug fix: these factories previously closed over the module-level
    # cartpole_balance_ppo_config.env, silently ignoring the `cfg` argument.
    # Use the compiled cfg.env so a caller-supplied config takes effect.
    collector_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cfg.env) for _ in range(collector_env_num)],
        cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[partial(wrapped_dmc2gym_env, cfg=cfg.env) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager
    )

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)  # fixed seeds for reproducible evaluation
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    # Standard serial pipeline: evaluate periodically, otherwise collect a
    # batch of samples and run PPO updates on it until a budget is hit or
    # the evaluator reports the stop value.
    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break
|
||
|
||
# Script entry point: train with the default cartpole-balance config.
if __name__ == '__main__':
    main(cartpole_balance_ppo_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import gym | ||
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator | ||
from ding.model import VAC | ||
from ding.policy import PPOPolicy | ||
from ding.envs import DingEnvWrapper, FinalEvalRewardEnv, BaseEnvManager | ||
from ding.config import compile_config | ||
from ding.utils import set_pkg_seed | ||
|
||
from dizoo.minigrid.config.minigrid_onppo_config import minigrid_ppo_config | ||
from minigrid.wrappers import FlatObsWrapper | ||
import numpy as np | ||
from tensorboardX import SummaryWriter | ||
import os | ||
import gymnasium | ||
|
||
|
||
class MinigridWrapper(gym.Wrapper):
    """Bridge a gymnasium MiniGrid env to the old gym 4-tuple step API
    expected by DingEnvWrapper, with a hard episode-length cap.

    gymnasium's ``step`` returns ``(obs, reward, terminated, truncated,
    info)`` and ``reset`` returns ``(obs, info)``; this wrapper collapses
    both back to the legacy shapes.
    """

    def __init__(self, env):
        super().__init__(env)
        # Spaces describe the FlatObsWrapper output. NOTE(review):
        # shape=(8, ) looks inconsistent with a flattened MiniGrid
        # observation -- confirm against the model's obs_shape.
        self._observation_space = gym.spaces.Box(
            low=float("-inf"),
            high=float("inf"),
            shape=(8,),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(9)
        self._action_space.seed(0)  # default seed
        self.reward_range = (float('-inf'), float('inf'))
        self.max_steps = minigrid_ppo_config.env.max_step
        # Robustness: also initialize here so a step() call before reset()
        # cannot raise AttributeError on self.cur_step.
        self.cur_step = 0

    def step(self, action):
        # Bug fix: fold gymnasium's `truncated` flag into the single
        # gym-style `done` instead of discarding it, so episodes cut short
        # by the underlying env still terminate for the collector.
        obs, reward, done, truncated, info = self.env.step(action)
        done = done or truncated
        self.cur_step += 1
        if self.cur_step > self.max_steps:
            done = True
        return obs, reward, done, info

    def reset(self):
        self.cur_step = 0
        # gymnasium reset() returns (obs, info); keep only obs.
        return self.env.reset()[0]
|
||
|
||
def wrapped_minigrid_env():
    """Construct the configured MiniGrid env, wrapped for DI-engine.

    The wrapper chain flattens observations, adapts the gymnasium API to
    the legacy gym 4-tuple, and tracks final evaluation reward.
    """
    wrapper_chain = [
        lambda env: FlatObsWrapper(env),
        lambda env: MinigridWrapper(env),
        lambda env: FinalEvalRewardEnv(env),
    ]
    base_env = gymnasium.make(minigrid_ppo_config.env.env_id)
    return DingEnvWrapper(base_env, cfg={'env_wrapper': wrapper_chain})
|
||
|
||
def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
    """Serial PPO training entry point for MiniGrid.

    Alternates periodic evaluation with collect/train iterations until the
    evaluator reports the stop value or either budget (env steps / learner
    iterations) is exhausted.
    """
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator,
        save_cfg=True
    )
    n_collect, n_eval = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[wrapped_minigrid_env] * n_collect, cfg=cfg.env.manager)
    evaluator_env = BaseEnvManager(env_fn=[wrapped_minigrid_env] * n_eval, cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)  # fixed seeds for reproducible evaluation
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    policy = PPOPolicy(cfg.policy, model=VAC(**cfg.policy.model))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, _ = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        batch = collector.collect(train_iter=learner.train_iter)
        learner.train(batch, collector.envstep)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break
|
||
|
||
# Script entry point: train with the default MiniGrid PPO config.
if __name__ == '__main__':
    main(minigrid_ppo_config)
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.