
Commit

Yuanmo committed Mar 11, 2024
2 parents 766004a + 5b984b0 commit ef50681
Showing 9 changed files with 379 additions and 15 deletions.
12 changes: 3 additions & 9 deletions pyproject.toml
@@ -13,7 +13,7 @@ packages = ["rllte"]

[project]
name = "rllte-core"
version = "0.0.1.beta12"
version = "0.0.1.beta13"
authors = [
{ name="Reinforcement Learning Evolution Foundation", email="friedrichyuan19990827@gmail.com" },
]
@@ -33,7 +33,7 @@ classifiers = [
]

dependencies = [
"gymnasium[accept-rom-license]",
"gymnasium[accept-rom-license, other]",
"torch",
"torchvision",
"termcolor",
@@ -56,13 +56,7 @@ tests = [
"isort>=5.0",
"black"
]
envs = [
"envpool",
"ale-py==0.8.1",
"dm-control",
"procgen",
"minigrid"
]

docs = [
"mkdocs-material",
"mkgendocs"
19 changes: 15 additions & 4 deletions rllte/common/prototype/base_reward.py
@@ -34,6 +34,7 @@
from rllte.common.preprocessing import process_action_space, process_observation_space
from rllte.common.utils import TorchRunningMeanStd, RewardForwardFilter


class BaseReward(ABC):
"""Base class of reward module.
@@ -61,8 +62,12 @@ def __init__(
obs_norm_type: str = "rms",
) -> None:
# get environment information
self.observation_space = envs.observation_space
self.action_space = envs.action_space
if isinstance(envs, VectorEnv):
self.observation_space = envs.single_observation_space
self.action_space = envs.single_action_space
else:
self.observation_space = envs.observation_space
self.action_space = envs.action_space
self.n_envs = envs.unwrapped.num_envs
## process the observation and action space
self.obs_shape: Tuple = process_observation_space(self.observation_space) # type: ignore
@@ -138,6 +143,7 @@ def init_normalization(self) -> None:
"""Initialize the normalization parameters for observations if the RMS is used."""
# TODO: better initialization parameters?
num_steps, num_iters = 128, 20
# for the vectorized environments with `Gymnasium2Torch` from rllte
try:
_, _ = self.envs.reset()
if self.obs_norm_type == "rms":
@@ -157,14 +163,19 @@
self.obs_norm.update(all_next_obs)
all_next_obs = []
except:
# for the outdated gym version
# for the normal vectorized environments
_ = self.envs.reset()
if self.obs_norm_type == "rms":
all_next_obs = []
for step in range(num_steps * num_iters):
actions = [self.action_space.sample() for _ in range(self.n_envs)]
actions = np.stack(actions)
next_obs, _, _, _ = self.envs.step(actions)
try:
# for the old gym output
next_obs, _, _, _ = self.envs.step(actions)
except:
# for the new gymnasium output
next_obs, _, _, _, _ = self.envs.step(actions)
all_next_obs += th.as_tensor(next_obs).view(-1, *self.obs_shape)
# update the running mean and std
if len(all_next_obs) % (num_steps * self.n_envs) == 0:
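
As context for the fallback branches above: the code has to tolerate both the old gym step signature (4 return values) and the gymnasium one (5 return values). A minimal sketch of that pattern, with an illustrative helper name that is not part of the commit:

```
import numpy as np

def sample_next_obs(envs, action_space, n_envs):
    """Step a vectorized env once, tolerating both step signatures.

    Old gym returns (obs, reward, done, info); gymnasium returns
    (obs, reward, terminated, truncated, info). This mirrors the
    try/except fallback used in `init_normalization` above.
    """
    actions = np.stack([action_space.sample() for _ in range(n_envs)])
    result = envs.step(actions)
    if len(result) == 4:   # old gym API
        next_obs, _, _, _ = result
    else:                  # gymnasium API
        next_obs, _, _, _, _ = result
    return next_obs
```
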
94 changes: 94 additions & 0 deletions rllte/env/README.md
@@ -0,0 +1,94 @@
Integrating RL environments in RLLTE is incredibly easy and efficient!

## Menu
1. [Installation](#installation)
2. [Usage](#usage)

## Installation

The instructions below assume you are running inside a conda environment.

### Atari
```
pip install ale-py==0.8.1
```

### Craftax

You will need a Jax GPU-enabled conda environment:

```
conda create -n rllte jaxlib==*cuda jax python=3.11 -c conda-forge
pip install craftax
pip install brax
pip install -e .[envs]
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

### DMC
```
pip install dm-control
```

### SuperMarioBros
```
pip install gym-super-mario-bros==7.4.0
```

### Minigrid
```
pip install minigrid
```

### Miniworld
```
pip install miniworld
```

### Procgen
```
pip install procgen
```

### Envpool
```
pip install envpool
```

## Usage

Each environment has a `make_env()` function in `rllte/env/<your_RL_env>/__init__.py` and its necessary wrappers in `rllte/env/<your_RL_env>/wrappers.py`. To add your custom environments, simply follow the same logic as the currently available environments, and the RL training will work flawlessly!
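
A minimal sketch of that layout for a hypothetical `rllte/env/my_env/__init__.py` (the module name, `make_my_env`, and the wrapper choices are illustrative only; the built-in factories additionally convert outputs to torch tensors via a `Gymnasium2Torch`-style wrapper, which is omitted here):

```
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import RecordEpisodeStatistics


def make_my_env(env_id: str = "CartPole-v1", num_envs: int = 8, device: str = "cpu"):
    """Create a vectorized environment following the rllte factory convention."""

    def make_single():
        env = gym.make(env_id)
        # task-specific wrappers would live in rllte/env/my_env/wrappers.py
        env = RecordEpisodeStatistics(env)
        return env

    envs = SyncVectorEnv([make_single for _ in range(num_envs)])
    # `device` is accepted only for signature parity with the built-in factories;
    # the tensor-conversion step that would use it is omitted in this sketch.
    return envs
```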

## Example training

```
from rllte.agent import PPO
from rllte.env import (
make_mario_env,
make_envpool_vizdoom_env,
make_envpool_procgen_env,
make_minigrid_env,
make_envpool_atari_env,
make_craftax_env
)
# define params
device = "cuda"
# define environment
env = make_craftax_env(
num_envs=32,
device=device,
)
# define agent
agent = PPO(
env=env,
device=device
)
# start training
agent.train(
num_train_steps=10_000_000,
)
```
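
The other environment factories imported above (`make_mario_env`, `make_envpool_atari_env`, and so on) are used the same way: swap the `make_craftax_env` call for the factory you need (each takes its own arguments), and the agent definition and `agent.train(...)` call stay unchanged.
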
17 changes: 15 additions & 2 deletions rllte/env/__init__.py
@@ -28,7 +28,6 @@
from .testing import make_multidiscrete_env as make_multidiscrete_env
from .testing import make_box_env as make_box_env
from .testing import make_discrete_env as make_discrete_env

from .utils import make_rllte_env as make_rllte_env

try:
@@ -52,6 +51,11 @@
except Exception:
pass

try:
from .miniworld import make_miniworld_env as make_miniworld_env
except Exception:
pass

try:
from .procgen import make_envpool_procgen_env as make_envpool_procgen_env
from .procgen import make_procgen_env as make_procgen_env
@@ -60,6 +64,15 @@

try:
from .mario import make_mario_env as make_mario_env
from .mario import make_mario_multilevel_env as make_mario_multilevel_env
except Exception:
pass

try:
from .craftax import make_craftax_env as make_craftax_env
except Exception:
pass

try:
from .vizdoom import make_envpool_vizdoom_env as make_envpool_vizdoom_env
except Exception:
pass
34 changes: 34 additions & 0 deletions rllte/env/craftax/__init__.py
@@ -0,0 +1,34 @@
from craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
from craftax_classic.envs.craftax_pixels_env import CraftaxClassicPixelsEnv
from environment_base.wrappers import (
LogWrapper,
BatchEnvWrapper,
OptimisticResetVecEnvWrapper,
)

from rllte.env.craftax.wrappers import TorchWrapper, ResizeTorchWrapper, RecordEpisodeStatistics4Craftax

def make_craftax_env(
env_id: str = "Craftax-Classic",
num_envs: int = 32,
reset_ratio: int = 16,
device: str = "cpu",
):

if env_id == "Craftax-Classic":
env = CraftaxClassicPixelsEnv()
elif env_id == "Craftax":
env = CraftaxPixelsEnv()
else:
raise ValueError(f"Unknown environment: {env_id}")

env = LogWrapper(env)
env = OptimisticResetVecEnvWrapper(env, num_envs=num_envs, reset_ratio=reset_ratio)
env = TorchWrapper(env, device=device)
env = ResizeTorchWrapper(env, (84, 84))
env = RecordEpisodeStatistics4Craftax(env)
env.num_envs = num_envs
env.single_observation_space = env.observation_space
env.single_action_space = env.action_space
return env

114 changes: 114 additions & 0 deletions rllte/env/craftax/wrappers.py
@@ -0,0 +1,114 @@
import numpy as np
import jax
import gymnasium as gym
import torch
from dataclasses import asdict
from brax.io import torch as brax_torch

class TorchWrapper(gym.Wrapper):
"""Wrapper that converts Jax tensors to PyTorch tensors."""

def __init__(self, env, device):
super().__init__(env)
self.device = device
self.env = env
self.default_params = env.default_params
self.metadata = {
'render.modes': ['human', 'rgb_array'],
}

# define obs and action space
obs_shape = env.observation_space(self.default_params).shape
self.observation_space = gym.spaces.Box(
low=-1e6, high=1e6, shape=obs_shape)
self.action_space = gym.spaces.Discrete(env.action_space(self.default_params).n)

# jit the reset function
def reset(key):
key1, key2 = jax.random.split(key)
obs, state = self.env.reset(key2)
return state, obs, key1, asdict(state)
self._reset = jax.jit(reset)

# jit the step function
def step(state, action):
obs, env_state, reward, done, info = self.env.step(rng=self._key, state=state, action=action)
return env_state, obs, reward, done, {**asdict(env_state), **info}
self._step = jax.jit(step)

def reset(self, seed=0, options=None):
self.seed(seed)
self._state, obs, self._key, info = self._reset(self._key)
return brax_torch.jax_to_torch(obs, device=self.device), info

def step(self, action):
action = brax_torch.torch_to_jax(action)
self._state, obs, reward, done, info = self._step(self._state, action)
obs = brax_torch.jax_to_torch(obs, device=self.device)
reward = brax_torch.jax_to_torch(reward, device=self.device)
terminateds = brax_torch.jax_to_torch(done, device=self.device)
truncateds = brax_torch.jax_to_torch(done, device=self.device)
info = brax_torch.jax_to_torch(info, device=self.device)
return obs, reward, terminateds, truncateds, info

def seed(self, seed: int = 0):
self._key = jax.random.PRNGKey(seed)

class ResizeTorchWrapper(gym.Wrapper):
"""Wrapper that resizes observations to a given shape."""

def __init__(self, env, shape):
super().__init__(env)
self.env = env
num_channels = env.observation_space.shape[-1]
self.shape = (num_channels, shape[0], shape[1])

# define obs and action space
self.observation_space = gym.spaces.Box(
low=-1e6, high=1e6, shape=self.shape)

def reset(self, seed=0, options=None):
obs, info = self.env.reset(seed, options)
obs = obs.permute(0, 3, 1, 2)
obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode='nearest')
return obs, info

def step(self, action):
obs, reward, terminateds, truncateds, info = self.env.step(action)
obs = obs.permute(0, 3, 1, 2)
obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode='nearest')
return obs, reward, terminateds, truncateds, info

class RecordEpisodeStatistics4Craftax(gym.Wrapper):
def __init__(self, env: gym.Env, deque_size: int = 100) -> None:
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None

def reset(self, **kwargs):
observations, infos = super().reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations, infos

def step(self, actions):
observations, rewards, terms, truncs, infos = super().step(actions)
self.episode_returns += rewards.cpu().numpy()
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - infos["returned_episode"].cpu().numpy().astype(np.int32)
self.episode_lengths *= 1 - infos["returned_episode"].cpu().numpy().astype(np.int32)
infos["episode"] = {}
infos["episode"]["r"] = self.returned_episode_returns
infos["episode"]["l"] = self.returned_episode_lengths

for idx, d in enumerate(terms):
if not d:
infos["episode"]["r"][idx] = 0
infos["episode"]["l"][idx] = 0

return observations, rewards, terms, truncs, infos
