-
-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/RLE-Foundation/rllte
- Loading branch information
Showing
9 changed files
with
379 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
Integrating RL environments in RLLTE is incredibly easy and efficient! | ||
|
||
## Menu | ||
1. [Installation](#installation) | ||
2. [Usage](#usage) | ||
3. [Example training](#example-training) | ||
|
||
## Installation | ||
|
||
The commands below assume that you are running inside a conda environment. | ||
|
||
### Atari | ||
``` | ||
pip install ale-py==0.8.1 | ||
``` | ||
|
||
### Craftax | ||
|
||
You will need a GPU-enabled JAX conda environment: | ||
|
||
``` | ||
conda create -n rllte jaxlib==*cuda jax python=3.11 -c conda-forge | ||
pip install craftax | ||
pip install brax | ||
pip install -e .[envs] | ||
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
``` | ||
|
||
### DMC | ||
``` | ||
pip install dm-control | ||
``` | ||
|
||
### SuperMarioBros | ||
``` | ||
pip install gym-super-mario-bros==7.4.0 | ||
``` | ||
|
||
### Minigrid | ||
``` | ||
pip install minigrid | ||
``` | ||
|
||
### Miniworld | ||
``` | ||
pip install miniworld | ||
``` | ||
|
||
### Procgen | ||
``` | ||
pip install procgen | ||
``` | ||
|
||
### Envpool | ||
``` | ||
pip install envpool | ||
``` | ||
|
||
## Usage | ||
|
||
Each environment has a `make_<env>_env()` factory function (for example, `make_craftax_env()`) in `rllte/env/<your_RL_env>/__init__.py` and its necessary wrappers in `rllte/env/<your_RL_env>/wrappers.py`. To add your own custom environment, simply follow the same structure as the currently available environments, and the RL training will work with it out of the box! | ||
|
||
## Example training | ||
|
||
``` | ||
from rllte.agent import PPO | ||
from rllte.env import ( | ||
make_mario_env, | ||
make_envpool_vizdoom_env, | ||
make_envpool_procgen_env, | ||
make_minigrid_env, | ||
make_envpool_atari_env, | ||
make_craftax_env | ||
) | ||
# define params | ||
device = "cuda" | ||
# define environment | ||
env = make_craftax_env( | ||
num_envs=32, | ||
device=device, | ||
) | ||
# define agent | ||
agent = PPO( | ||
env=env, | ||
device=device | ||
) | ||
# start training | ||
agent.train( | ||
num_train_steps=10_000_000, | ||
) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from craftax.envs.craftax_pixels_env import CraftaxPixelsEnv | ||
from craftax_classic.envs.craftax_pixels_env import CraftaxClassicPixelsEnv | ||
from environment_base.wrappers import ( | ||
LogWrapper, | ||
BatchEnvWrapper, | ||
OptimisticResetVecEnvWrapper, | ||
) | ||
|
||
from rllte.env.craftax.wrappers import TorchWrapper, ResizeTorchWrapper, RecordEpisodeStatistics4Craftax | ||
|
||
def make_craftax_env(
    env_id: str = "Craftax-Classic",
    num_envs: int = 32,
    reset_ratio: int = 16,
    device: str = "cpu",
):
    """Build a batched Craftax pixel environment wrapped for PyTorch training.

    Args:
        env_id: Either ``"Craftax-Classic"`` or ``"Craftax"``.
        num_envs: Number of parallel environments handed to
            ``OptimisticResetVecEnvWrapper``.
        reset_ratio: Forwarded unchanged to ``OptimisticResetVecEnvWrapper``.
        device: Device passed to ``TorchWrapper`` for the converted tensors.

    Returns:
        The fully wrapped environment, with ``num_envs``,
        ``single_observation_space`` and ``single_action_space`` set on the
        outermost wrapper.

    Raises:
        ValueError: If ``env_id`` is not one of the supported names.
    """
    # Reject unknown ids up front, before any environment is constructed.
    if env_id not in ("Craftax-Classic", "Craftax"):
        raise ValueError(f"Unknown environment: {env_id}")

    base_env = (
        CraftaxClassicPixelsEnv()
        if env_id == "Craftax-Classic"
        else CraftaxPixelsEnv()
    )

    # JAX-side wrappers: episode logging, then vectorisation with resets.
    vec_env = OptimisticResetVecEnvWrapper(
        LogWrapper(base_env),
        num_envs=num_envs,
        reset_ratio=reset_ratio,
    )

    # Torch-side wrappers: tensor conversion, 84x84 resize, episode statistics.
    wrapped = RecordEpisodeStatistics4Craftax(
        ResizeTorchWrapper(TorchWrapper(vec_env, device=device), (84, 84))
    )

    # Attributes a vectorised-env consumer reads from the outermost wrapper.
    wrapped.num_envs = num_envs
    wrapped.single_observation_space = wrapped.observation_space
    wrapped.single_action_space = wrapped.action_space
    return wrapped
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import numpy as np | ||
import jax | ||
import gymnasium as gym | ||
import torch | ||
from dataclasses import asdict | ||
from brax.io import torch as brax_torch | ||
|
||
class TorchWrapper(gym.Wrapper):
    """Wrapper that converts Jax tensors to PyTorch tensors.

    The wrapped environment is driven functionally: it is called as
    ``env.reset(key)`` and ``env.step(rng=..., state=..., action=...)``, and
    exposes ``default_params``, ``observation_space(params)`` and
    ``action_space(params)``. This wrapper owns the environment state and the
    PRNG key, and presents ``reset``/``step`` methods that accept and return
    PyTorch tensors.

    Args:
        env: The JAX environment to wrap.
        device: Device handed to ``brax_torch.jax_to_torch`` for every
            converted output.
    """

    def __init__(self, env, device):
        super().__init__(env)
        self.device = device
        self.env = env
        self.default_params = env.default_params
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
        }

        # define obs and action space
        obs_shape = env.observation_space(self.default_params).shape
        self.observation_space = gym.spaces.Box(
            low=-1e6, high=1e6, shape=obs_shape)
        self.action_space = gym.spaces.Discrete(env.action_space(self.default_params).n)

        # Always hold a valid key so ``reset(seed=None)`` has something to
        # continue from.
        self._key = jax.random.PRNGKey(0)

        # jit the reset function
        def reset(key):
            key1, key2 = jax.random.split(key)
            obs, state = self.env.reset(key2)
            return state, obs, key1, asdict(state)
        self._reset = jax.jit(reset)

        # jit the step function
        # The key is an explicit argument and is split on every call. Reading
        # ``self._key`` inside the jitted closure would freeze it as a
        # compile-time constant, so every step would reuse the same randomness.
        def step(key, state, action):
            next_key, step_key = jax.random.split(key)
            obs, env_state, reward, done, info = self.env.step(
                rng=step_key, state=state, action=action)
            return next_key, env_state, obs, reward, done, {**asdict(env_state), **info}
        self._step = jax.jit(step)

    def reset(self, seed=0, options=None):
        """Reset the environment and return ``(obs, info)``.

        Args:
            seed: Integer seed used to re-create the PRNG key. ``None`` keeps
                the current key stream instead of reseeding.
            options: Accepted for API compatibility; not used.

        Returns:
            The observation as a PyTorch tensor on ``self.device`` and the
            environment state as a dict (left as JAX values).
        """
        # ``gym.Wrapper.reset`` forwards ``seed=None`` by default; a key cannot
        # be built from ``None``, so treat it as "do not reseed".
        if seed is not None:
            self.seed(seed)
        self._state, obs, self._key, info = self._reset(self._key)
        return brax_torch.jax_to_torch(obs, device=self.device), info

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: PyTorch tensor of actions; converted to JAX before use.

        Returns:
            ``(obs, reward, terminateds, truncateds, info)``, all converted to
            PyTorch on ``self.device``.
        """
        action = brax_torch.torch_to_jax(action)
        self._key, self._state, obs, reward, done, info = self._step(
            self._key, self._state, action)
        obs = brax_torch.jax_to_torch(obs, device=self.device)
        reward = brax_torch.jax_to_torch(reward, device=self.device)
        # NOTE(review): the environment returns a single ``done`` flag, so
        # both outputs mirror it. Confirm that consumers do not treat
        # ``truncateds`` as a time-limit signal.
        terminateds = brax_torch.jax_to_torch(done, device=self.device)
        truncateds = brax_torch.jax_to_torch(done, device=self.device)
        info = brax_torch.jax_to_torch(info, device=self.device)
        return obs, reward, terminateds, truncateds, info

    def seed(self, seed: int = 0):
        """Replace the internal PRNG key with one derived from ``seed``."""
        self._key = jax.random.PRNGKey(seed)
|
||
class ResizeTorchWrapper(gym.Wrapper):
    """Wrapper that resizes observations to a given shape.

    Observations from the wrapped environment are permuted with
    ``(0, 3, 1, 2)`` (so they must be 4-D with the last axis moved to
    position 1) and then resized with nearest-neighbour interpolation.

    Args:
        env: Environment producing 4-D PyTorch observations.
        shape: Target ``(height, width)`` of the resized observations.
    """

    def __init__(self, env, shape):
        super().__init__(env)
        self.env = env
        num_channels = env.observation_space.shape[-1]
        self.shape = (num_channels, shape[0], shape[1])

        # define obs and action space
        self.observation_space = gym.spaces.Box(
            low=-1e6, high=1e6, shape=self.shape)

    def _resize(self, obs):
        """Move the last axis to position 1 and resize to the target size."""
        obs = obs.permute(0, 3, 1, 2)
        return torch.nn.functional.interpolate(obs, size=self.shape[1:], mode='nearest')

    def reset(self, seed=0, options=None):
        """Reset the wrapped environment and return the resized observation.

        ``seed`` and ``options`` are forwarded by keyword, because Gymnasium
        declares them keyword-only on ``Env.reset``.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self._resize(obs), info

    def step(self, action):
        """Step the wrapped environment and return the resized observation."""
        obs, reward, terminateds, truncateds, info = self.env.step(action)
        return self._resize(obs), reward, terminateds, truncateds, info
|
||
class RecordEpisodeStatistics4Craftax(gym.Wrapper):
    """Accumulate per-environment episode returns and lengths.

    After every ``step`` the wrapper adds ``infos["episode"]`` with keys
    ``"r"`` (returns) and ``"l"`` (lengths). Entries are non-zero only for
    environments whose ``terms`` flag is set on that step. The running
    accumulators are cleared wherever ``infos["returned_episode"]`` is set.

    Args:
        env: Batched environment whose ``step`` returns PyTorch tensors and an
            ``infos`` dict containing ``"returned_episode"``.
        deque_size: Accepted for signature compatibility; currently unused.
    """

    def __init__(self, env: gym.Env, deque_size: int = 100) -> None:
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_returns = None
        self.episode_lengths = None

    def reset(self, **kwargs):
        """Reset the wrapped environment and zero all running statistics."""
        observations, infos = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations, infos

    def step(self, actions):
        """Step the wrapped environment and attach episode statistics.

        Returns:
            The wrapped environment's step tuple, unchanged except that
            ``infos["episode"]`` holds fresh ``"r"``/``"l"`` arrays.
        """
        observations, rewards, terms, truncs, infos = super().step(actions)
        self.episode_returns += rewards.cpu().numpy()
        self.episode_lengths += 1

        # Snapshot the totals into new arrays. Handing out the internal
        # buffers would let the next step overwrite infos a caller has stored.
        finished = terms.cpu().numpy().astype(bool)
        reported_returns = self.episode_returns.copy()
        reported_lengths = self.episode_lengths.copy()
        # Only environments that terminated on this step report statistics.
        reported_returns[~finished] = 0
        reported_lengths[~finished] = 0
        self.returned_episode_returns = reported_returns
        self.returned_episode_lengths = reported_lengths

        # Clear the accumulators for environments that just finished.
        keep = 1 - infos["returned_episode"].cpu().numpy().astype(np.int32)
        self.episode_returns *= keep
        self.episode_lengths *= keep

        infos["episode"] = {"r": reported_returns, "l": reported_lengths}
        return observations, rewards, terms, truncs, infos
Oops, something went wrong.