diff --git a/README.md b/README.md
index 3601d9e77c..127b15745c 100644
--- a/README.md
+++ b/README.md
@@ -282,6 +282,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html)<br>环境指南 |
 | 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>环境指南 |
 | 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>环境指南 |
+| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br>环境指南 |
 
 ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
diff --git a/dizoo/classic_control/acrobot/__init__.py b/dizoo/classic_control/acrobot/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/classic_control/acrobot/acrobot.gif b/dizoo/classic_control/acrobot/acrobot.gif
new file mode 100644
index 0000000000..3eef302afe
Binary files /dev/null and b/dizoo/classic_control/acrobot/acrobot.gif differ
diff --git a/dizoo/classic_control/acrobot/config/__init__.py b/dizoo/classic_control/acrobot/config/__init__.py
new file mode 100644
index 0000000000..036dbf6a93
--- /dev/null
+++ b/dizoo/classic_control/acrobot/config/__init__.py
@@ -0,0 +1 @@
+from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config
diff --git a/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py b/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
new file mode 100644
index 0000000000..4957db987f
--- /dev/null
+++ b/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+acrobot_dqn_config = dict(
+    exp_name='acrobot_dqn_seed0',
+    env=dict(
+        collector_env_num=8,
+        evaluator_env_num=8,
+        n_evaluator_episode=8,
+        stop_value=-60,
+        env_id='Acrobot-v1',
+        replay_path='acrobot_dqn_seed0/video',
+    ),
+    policy=dict(
+        cuda=True,
+        model=dict(
+            obs_shape=6,
+            action_shape=3,
+            encoder_hidden_size_list=[256, 256],
+            dueling=True,
+        ),
+        nstep=3,
+        discount_factor=0.99,
+        learn=dict(
+            update_per_collect=10,
+            batch_size=128,
+            learning_rate=0.0001,
+            target_update_freq=250,
+        ),
+        collect=dict(n_sample=96, ),
+        eval=dict(evaluator=dict(eval_freq=2000, )),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=1.,
+                end=0.05,
+                decay=250000,
+            ),
+            replay_buffer=dict(replay_buffer_size=100000, ),
+        ),
+    ),
+)
+acrobot_dqn_config = EasyDict(acrobot_dqn_config)
+main_config = acrobot_dqn_config
+acrobot_dqn_create_config = dict(
+    env=dict(type='acrobot', import_names=['dizoo.classic_control.acrobot.envs.acrobot_env']),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='dqn'),
+    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+acrobot_dqn_create_config = EasyDict(acrobot_dqn_create_config)
+create_config = acrobot_dqn_create_config
+
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0)
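Note on the exploration block in the config above: `other.eps` defines an exponential epsilon-greedy schedule with start=1.0, end=0.05 and decay=250000. The sketch below is a minimal, standalone illustration of how such a schedule evolves, assuming the common form `end + (start - end) * exp(-step / decay)`; the exact formula used by DI-engine's epsilon-greedy helper is not shown in this diff and may differ.

import math

def eps_at(step, start=1.0, end=0.05, decay=250000):
    # assumed exponential decay from `start` toward `end` as collected env steps grow
    return end + (start - end) * math.exp(-step / decay)

for step in (0, 50000, 250000, 1000000):
    print(step, round(eps_at(step), 3))

Under this assumption, epsilon is still roughly 0.4 after 250k environment steps and only approaches the 0.05 floor around a million steps.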
diff --git a/dizoo/classic_control/acrobot/envs/__init__.py b/dizoo/classic_control/acrobot/envs/__init__.py
new file mode 100644
index 0000000000..be6537f2c9
--- /dev/null
+++ b/dizoo/classic_control/acrobot/envs/__init__.py
@@ -0,0 +1 @@
+from .acrobot_env import AcroBotEnv
diff --git a/dizoo/classic_control/acrobot/envs/acrobot_env.py b/dizoo/classic_control/acrobot/envs/acrobot_env.py
new file mode 100644
index 0000000000..3c26323315
--- /dev/null
+++ b/dizoo/classic_control/acrobot/envs/acrobot_env.py
@@ -0,0 +1,98 @@
+from typing import Any, List, Union, Optional
+import time
+import gym
+import copy
+import numpy as np
+from easydict import EasyDict
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register('acrobot')
+class AcroBotEnv(BaseEnv):
+
+    def __init__(self, cfg: dict = {}) -> None:
+        self._cfg = cfg
+        self._init_flag = False
+        self._replay_path = None
+        # obs = [cos(theta1), sin(theta1), cos(theta2), sin(theta2), theta1_dot, theta2_dot];
+        # the velocity bounds correspond to Acrobot-v1's limits of 4*pi and 9*pi rad/s.
+        self._observation_space = gym.spaces.Box(
+            low=np.array([-1.0, -1.0, -1.0, -1.0, -12.57, -28.27]),
+            high=np.array([1.0, 1.0, 1.0, 1.0, 12.57, 28.27]),
+            shape=(6, ),
+            dtype=np.float32
+        )
+        self._action_space = gym.spaces.Discrete(3)
+        self._action_space.seed(0)  # default seed
+        self._reward_space = gym.spaces.Box(low=-1.0, high=0.0, shape=(1, ), dtype=np.float32)
+
+    def reset(self) -> np.ndarray:
+        if not self._init_flag:
+            self._env = gym.make('Acrobot-v1')
+            if self._replay_path is not None:
+                self._env = gym.wrappers.RecordVideo(
+                    self._env,
+                    video_folder=self._replay_path,
+                    episode_trigger=lambda episode_id: True,
+                    name_prefix='rl-video-{}'.format(id(self))
+                )
+            self._init_flag = True
+        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+            np_seed = 100 * np.random.randint(1, 1000)
+            self._env.seed(self._seed + np_seed)
+            self._action_space.seed(self._seed + np_seed)
+        elif hasattr(self, '_seed'):
+            self._env.seed(self._seed)
+            self._action_space.seed(self._seed)
+        self._observation_space = self._env.observation_space
+        self._eval_episode_return = 0
+        obs = self._env.reset()
+        obs = to_ndarray(obs)
+        return obs
+
+    def close(self) -> None:
+        if self._init_flag:
+            self._env.close()
+        self._init_flag = False
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        self._seed = seed
+        self._dynamic_seed = dynamic_seed
+        np.random.seed(self._seed)
+
+    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+        if isinstance(action, np.ndarray) and action.shape == (1, ):
+            action = action.squeeze()  # 0-dim array
+        obs, rew, done, info = self._env.step(action)
+        self._eval_episode_return += rew
+        if done:
+            info['eval_episode_return'] = self._eval_episode_return
+        obs = to_ndarray(obs)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrapped into an array with shape (1, )
+        return BaseEnvTimestep(obs, rew, done, info)
+
+    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+        if replay_path is None:
+            replay_path = './video'
+        self._replay_path = replay_path
+
+    def random_action(self) -> np.ndarray:
+        random_action = self.action_space.sample()
+        random_action = to_ndarray([random_action], dtype=np.int64)
+        return random_action
+
+    @property
+    def observation_space(self) -> gym.spaces.Space:
+        return self._observation_space
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space
+
+    @property
+    def reward_space(self) -> gym.spaces.Space:
+        return self._reward_space
+
+    def __repr__(self) -> str:
+        return "DI-engine Acrobot Env"
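The wrapper above implements DI-engine's `BaseEnv` interface. A minimal interactive sketch, using only methods defined in this file (assuming DI-engine and gym are installed):

from dizoo.classic_control.acrobot.envs import AcroBotEnv

env = AcroBotEnv({})
env.seed(0, dynamic_seed=False)
obs = env.reset()                          # np.ndarray with shape (6, )
timestep = env.step(env.random_action())   # BaseEnvTimestep(obs, reward, done, info)
print(timestep.obs.shape, timestep.reward, timestep.done)
env.close()

Rewards come back as float32 arrays of shape (1, ), matching the asserts in the test module that follows.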
diff --git a/dizoo/classic_control/acrobot/envs/test_acrobot_env.py b/dizoo/classic_control/acrobot/envs/test_acrobot_env.py
new file mode 100644
index 0000000000..fba0914cfa
--- /dev/null
+++ b/dizoo/classic_control/acrobot/envs/test_acrobot_env.py
@@ -0,0 +1,35 @@
+import pytest
+import numpy as np
+from dizoo.classic_control.acrobot.envs import AcroBotEnv
+
+
+@pytest.mark.envtest
+class TestAcrobotEnv:
+
+    def test_naive(self):
+        env = AcroBotEnv({})
+        env.seed(314, dynamic_seed=False)
+        assert env._seed == 314
+        obs = env.reset()
+        assert obs.shape == (6, )
+        for _ in range(5):
+            env.reset()
+            np.random.seed(314)
+            print('=' * 60)
+            for i in range(10):
+                # Both ``env.random_action()`` and sampling the action space directly
+                # generate legal random actions.
+                if i < 5:
+                    random_action = np.array([env.action_space.sample()])
+                else:
+                    random_action = env.random_action()
+                timestep = env.step(random_action)
+                print(timestep)
+                assert isinstance(timestep.obs, np.ndarray)
+                assert isinstance(timestep.done, bool)
+                assert timestep.obs.shape == (6, )
+                assert timestep.reward.shape == (1, )
+                assert timestep.reward >= env.reward_space.low
+                assert timestep.reward <= env.reward_space.high
+        print(env.observation_space, env.action_space, env.reward_space)
+        env.close()
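Usage note for the files added in this diff: the environment test is a standard pytest module (marked `envtest`), and training can be launched either by running the config file directly (its `__main__` block calls `serial_pipeline`) or programmatically, as in the sketch below (assuming DI-engine is installed and the repo is importable):

from ding.entry import serial_pipeline
from dizoo.classic_control.acrobot.config.acrobot_dqn_config import main_config, create_config

# Mirrors the config file's __main__ entry; training stops once the evaluator
# reaches the configured stop_value of -60 on Acrobot-v1.
serial_pipeline((main_config, create_config), seed=0)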