
feature(gry): add acrobot env and dqn config #577

Merged 6 commits on Feb 7, 2023
1 change: 1 addition & 0 deletions README.md
@@ -282,6 +282,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html) <br>env guide |
| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>env guide |
| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>env guide |
| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br>env guide |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

0 changes: 0 additions & 0 deletions dizoo/classic_control/acrobot/__init__.py
Empty file.
Binary file added dizoo/classic_control/acrobot/acrobot.gif
1 change: 1 addition & 0 deletions dizoo/classic_control/acrobot/config/__init__.py
@@ -0,0 +1 @@
from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config
55 changes: 55 additions & 0 deletions dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
@@ -0,0 +1,55 @@
from easydict import EasyDict

acrobot_dqn_config = dict(
    exp_name='acrobot_dqn_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=-60,  # Acrobot gives -1 per step, so this targets episodes solved in about 60 steps
        env_id='Acrobot-v1',
        replay_path='acrobot_dqn_seed0/video',
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=6,
            action_shape=3,
            encoder_hidden_size_list=[256, 256],
            dueling=True,  # dueling value/advantage head
        ),
        nstep=3,  # n-step TD target
        discount_factor=0.99,
        learn=dict(
            update_per_collect=10,
            batch_size=128,
            learning_rate=0.0001,
            target_update_freq=250,  # hard target-network update interval (in train iterations)
        ),
        collect=dict(n_sample=96, ),
        eval=dict(evaluator=dict(eval_freq=2000, )),
        other=dict(
            eps=dict(  # epsilon-greedy exploration schedule
                type='exp',
                start=1.,
                end=0.05,
                decay=250000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    ),
)
acrobot_dqn_config = EasyDict(acrobot_dqn_config)
main_config = acrobot_dqn_config
acrobot_dqn_create_config = dict(
    env=dict(type='acrobot', import_names=['dizoo.classic_control.acrobot.envs.acrobot_env']),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)
acrobot_dqn_create_config = EasyDict(acrobot_dqn_create_config)
create_config = acrobot_dqn_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), seed=0)
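
With `type='exp'`, epsilon decays exponentially from `start` toward `end` as environment steps accumulate. A minimal sketch of such a schedule, for intuition only (DI-engine's own epsilon-greedy helper may differ in detail):

import math

def exp_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: int = 250000) -> float:
    """Illustrative exponential epsilon schedule, not DI-engine's exact helper."""
    return end + (start - end) * math.exp(-step / decay)

# Epsilon stays near 1.0 early on and approaches 0.05 as steps accumulate:
print(exp_epsilon(0))        # 1.0
print(exp_epsilon(250000))   # ~0.40
print(exp_epsilon(1000000))  # ~0.067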
1 change: 1 addition & 0 deletions dizoo/classic_control/acrobot/envs/__init__.py
@@ -0,0 +1 @@
from .acrobot_env import AcroBotEnv
98 changes: 98 additions & 0 deletions dizoo/classic_control/acrobot/envs/acrobot_env.py
@@ -0,0 +1,98 @@
from typing import Union, Optional
import gym
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('acrobot')
class AcroBotEnv(BaseEnv):

    def __init__(self, cfg: dict = {}) -> None:
        self._cfg = cfg
        self._init_flag = False
        self._replay_path = None
        # Observation: [cos(theta1), sin(theta1), cos(theta2), sin(theta2), theta1_dot, theta2_dot];
        # the angular velocities are bounded by roughly 4*pi (12.57) and 9*pi (28.27).
        self._observation_space = gym.spaces.Box(
            low=np.array([-1.0, -1.0, -1.0, -1.0, -12.57, -28.27]),
            high=np.array([1.0, 1.0, 1.0, 1.0, 12.57, 28.27]),
            shape=(6, ),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(3)
        self._action_space.seed(0)  # default seed
        self._reward_space = gym.spaces.Box(low=-1.0, high=0.0, shape=(1, ), dtype=np.float32)

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make('Acrobot-v1')
            if self._replay_path is not None:
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix='rl-video-{}'.format(id(self))
                )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
            self._action_space.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
            self._action_space.seed(self._seed)
        self._observation_space = self._env.observation_space
        self._eval_episode_return = 0
        obs = self._env.reset()
        obs = to_ndarray(obs)
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
        if isinstance(action, np.ndarray) and action.shape == (1, ):
            action = action.squeeze()  # squeeze to a 0-dim array
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        obs = to_ndarray(obs)
        rew = to_ndarray([rew]).astype(np.float32)  # wrapped so it is transferred as an array with shape (1, )
        return BaseEnvTimestep(obs, rew, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Acrobot Env"
35 changes: 35 additions & 0 deletions dizoo/classic_control/acrobot/envs/test_acrobot_env.py
@@ -0,0 +1,35 @@
import pytest
import numpy as np
from dizoo.classic_control.acrobot.envs import AcroBotEnv


@pytest.mark.envtest
class TestAcrobotEnv:

    def test_naive(self):
        env = AcroBotEnv({})
        env.seed(314, dynamic_seed=False)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (6, )
        for _ in range(5):
            env.reset()
        np.random.seed(314)
        print('=' * 60)
        for i in range(10):
            # Both ``env.random_action()`` and sampling from ``env.action_space``
            # directly generate legal random actions.
            if i < 5:
                random_action = np.array([env.action_space.sample()])
            else:
                random_action = env.random_action()
            timestep = env.step(random_action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (6, )
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
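
Since the test is gated behind the ``@pytest.mark.envtest`` marker, it can be selected on its own, e.g. ``pytest -m envtest dizoo/classic_control/acrobot/envs/test_acrobot_env.py`` (assuming the repo's standard pytest setup).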