[RLlib] Add restart-failed-env option to new api stack. (#47608)
sven1977 authored Sep 11, 2024
1 parent 03ab9b3 commit e75f5e7
Showing 8 changed files with 340 additions and 35 deletions.
35 changes: 35 additions & 0 deletions rllib/BUILD
@@ -2961,6 +2961,41 @@ py_test(
args = ["--as-test", "--evaluation-parallel-to-training", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-env-runners=3", "--evaluation-duration=211", "--evaluation-duration-unit=timesteps"]
)

# subdirectory: fault_tolerance/
# ....................................
py_test(
name = "examples/fault_tolerance/crashing_cartpole_recreate_failed_env_runners_appo",
main = "examples/fault_tolerance/crashing_and_stalling_env.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/fault_tolerance/crashing_and_stalling_env.py"],
args = ["--algo=APPO", "--enable-new-api-stack", "--as-test", "--stop-reward=450.0"]
)
py_test(
name = "examples/fault_tolerance/crashing_cartpole_restart_failed_envs_appo",
main = "examples/fault_tolerance/crashing_and_stalling_env.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/fault_tolerance/crashing_and_stalling_env.py"],
args = ["--algo=APPO", "--enable-new-api-stack", "--as-test", "--restart-failed-envs", "--stop-reward=450.0"]
)
py_test(
name = "examples/fault_tolerance/crashing_and_stalling_cartpole_restart_failed_envs_ppo",
main = "examples/fault_tolerance/crashing_and_stalling_env.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/fault_tolerance/crashing_and_stalling_env.py"],
args = ["--algo=PPO", "--enable-new-api-stack", "--as-test", "--restart-failed-envs", "--stall", "--stop-reward=450.0"]
)
py_test(
name = "examples/fault_tolerance/crashing_and_stalling_multi_agent_cartpole_restart_failed_envs_ppo",
main = "examples/fault_tolerance/crashing_and_stalling_env.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/fault_tolerance/crashing_and_stalling_env.py"],
args = ["--algo=PPO", "--num-agents=2", "--enable-new-api-stack", "--as-test", "--restart-failed-envs", "--stop-reward=800.0"]
)

# subdirectory: gpus/
# ....................................
py_test(
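The `--restart-failed-envs` flag passed to these new test targets corresponds, on the new API stack, to the `restart_failed_sub_environments` fault-tolerance setting that the EnvRunner changes below act on. A minimal sketch of enabling it through the public config API (the example script's own argument parsing is not part of this diff):

# Minimal sketch: enable env restarts on the new API stack.
# PPOConfig, api_stack(), and fault_tolerance() are the standard public APIs;
# the exact wiring inside crashing_and_stalling_env.py is not shown here.
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")  # the tests register a crashing env here instead
    .fault_tolerance(
        # On an env reset()/step() error, the EnvRunner re-creates its env via
        # make_env() and keeps sampling instead of failing.
        restart_failed_sub_environments=True,
    )
)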
63 changes: 63 additions & 0 deletions rllib/env/env_runner.py
@@ -1,4 +1,5 @@
import abc
import logging
from typing import Any, Dict, Tuple, TYPE_CHECKING

import gymnasium as gym
@@ -13,8 +14,13 @@
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

logger = logging.getLogger("ray.rllib")

tf1, tf, _ = try_import_tf()

ENV_RESET_FAILURE = "env_reset_failure"
ENV_STEP_FAILURE = "env_step_failure"


# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
# a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@@ -44,6 +50,8 @@ def __init__(self, *, config: "AlgorithmConfig", **kwargs):
**kwargs: Forward compatibility kwargs.
"""
self.config = config.copy(copy_frozen=False)
self.env = None

super().__init__(**kwargs)

# This eager check is necessary for certain all-framework tests
@@ -66,6 +74,19 @@ def assert_healthy(self):
AssertionError: If the EnvRunner Actor has NOT been properly initialized.
"""

# TODO: Make this an abstract method that must be implemented.
def make_env(self):
"""Creates the RL environment for this EnvRunner and assigns it to `self.env`.
Note that users should be able to change the EnvRunner's config (e.g. change
`self.config.env_config`) and then call this method to create new environments
with the updated configuration.
It should also be called after a failure of an earlier env in order to clean up
the existing env (for example `close()` it), re-create a new one, and then
continue sampling with that new env.
"""
pass

@abc.abstractmethod
def sample(self, **kwargs) -> Any:
"""Returns experiences (of any form) sampled from this EnvRunner.
@@ -100,6 +121,48 @@ def __del__(self) -> None:
"""If this Actor is deleted, clears all resources used by it."""
pass

def _try_env_reset(self):
"""Tries resetting the env and - if an error orrurs - handles it gracefully."""
# Try to reset.
try:
obs, infos = self.env.reset()
# Everything ok -> return.
return obs, infos
# Error.
except Exception as e:
# If user wants to simply restart the env -> recreate env and try again
# (calling this method recursively until success).
if self.config.restart_failed_sub_environments:
logger.exception(
"Resetting the env resulted in an error! The original error "
f"is: {e.args[0]}"
)
# Recreate the env and simply try again.
self.make_env()
return self._try_env_reset()
else:
raise e

def _try_env_step(self, actions):
"""Tries stepping the env and - if an error orrurs - handles it gracefully."""
try:
results = self.env.step(actions)
return results
except Exception as e:
if self.config.restart_failed_sub_environments:
logger.exception(
"Stepping the env resulted in an error! The original error "
f"is: {e.args[0]}"
)
# Recreate the env.
self.make_env()
# And return that the stepping failed. The caller will then handle
# specific cleanup operations (for example discarding thus-far collected
# data and repeating the step attempt).
return ENV_STEP_FAILURE
else:
raise e

def _convert_to_tensor(self, struct) -> TensorType:
"""Converts structs to a framework-specific tensor."""

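Both helpers are meant to be called from the subclasses' sampling loops: `_try_env_reset()` retries internally (re-creating the env via `make_env()`) until a reset succeeds, whereas `_try_env_step()` re-creates the env and returns the `ENV_STEP_FAILURE` sentinel so the caller can discard its partial rollout and start over. A hedged sketch of that calling pattern follows; the function below is illustrative only, the real loops are in the env-runner diffs further down:

from ray.rllib.env.env_runner import ENV_STEP_FAILURE

def sample_with_restarts(runner, num_timesteps):
    """Illustrative pattern, not part of this diff.

    `runner` is assumed to be an EnvRunner subclass that implements
    `make_env()` and exposes a gymnasium-style `runner.env`.
    """
    # The reset is retried inside the helper until it succeeds.
    obs, infos = runner._try_env_reset()
    rollout, ts = [], 0
    while ts < num_timesteps:
        actions = runner.env.action_space.sample()  # random actions for the sketch
        results = runner._try_env_step(actions)
        if results == ENV_STEP_FAILURE:
            # The env was already re-created inside the helper; throw away the
            # partial rollout and start over from a fresh reset.
            return sample_with_restarts(runner, num_timesteps)
        obs, rewards, terminateds, truncateds, infos = results
        rollout.append((obs, actions, rewards, terminateds, truncateds, infos))
        ts += 1
    return rollout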
59 changes: 37 additions & 22 deletions rllib/env/multi_agent_env_runner.py
@@ -16,7 +16,7 @@
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.env.utils import _gym_env_creator
@@ -230,9 +230,10 @@ def _sample_timesteps(
# leak).
self._ongoing_episodes_for_metrics.clear()

# Reset the environment.
# Try resetting the environment.
# TODO (simon): Check, if we need here the seed from the config.
obs, infos = self.env.reset()
obs, infos = self._try_env_reset()

self._cached_to_module = None

# Call `on_episode_start()` callbacks.
@@ -310,12 +311,21 @@ def _sample_timesteps(
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS)
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
# Step the environment.

# Try stepping the environment.
# TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
# Support vectorized multi-agent envs.
obs, rewards, terminateds, truncateds, infos = self.env.step(
actions_for_env[0]
)
results = self._try_env_step(actions_for_env[0])
# If any failure occurs during stepping -> Throw away all data collected
# thus far and restart sampling procedure.
if results == ENV_STEP_FAILURE:
return self._sample_timesteps(
num_timesteps=num_timesteps,
explore=explore,
random_actions=random_actions,
force_reset=True,
)
obs, rewards, terminateds, truncateds, infos = results
ts += self._increase_sampled_metrics(self.num_envs, obs, self._episode)

# TODO (sven): This simple approach to re-map `to_env` from a
@@ -376,7 +386,7 @@ def _sample_timesteps(
self._make_on_episode_callback("on_episode_created")

# Reset the environment.
obs, infos = self.env.reset()
obs, infos = self._try_env_reset()
# Add initial observations and infos.
self._episode.add_env_reset(observations=obs, infos=infos)

@@ -442,9 +452,9 @@ def _sample_episodes(
"agent_to_module_mapping_fn": self.config.policy_mapping_fn,
}

# Reset the environment.
# Try resetting the environment.
# TODO (simon): Check, if we need here the seed from the config.
obs, infos = self.env.reset()
obs, infos = self._try_env_reset()
# Set initial obs and infos in the episodes.
_episode.add_env_reset(observations=obs, infos=infos)
self._make_on_episode_callback("on_episode_start", _episode)
@@ -507,12 +517,21 @@ def _sample_episodes(
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS)
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
# Step the environment.

# Try stepping the environment.
# TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
# Support vectorized multi-agent envs.
obs, rewards, terminateds, truncateds, infos = self.env.step(
actions_for_env[0]
)
results = self._try_env_step(actions_for_env[0])
# If any failure occurs during stepping -> Throw away all data collected
# thus far and restart sampling procedure.
if results == ENV_STEP_FAILURE:
return self._sample_episodes(
num_episodes=num_episodes,
explore=explore,
random_actions=random_actions,
)
obs, rewards, terminateds, truncateds, infos = results

ts += self._increase_sampled_metrics(self.num_envs, obs, _episode)

# TODO (sven): This simple approach to re-map `to_env` from a
@@ -587,8 +606,8 @@ def _sample_episodes(
_episode = self._new_episode()
self._make_on_episode_callback("on_episode_created", _episode)

# Reset the environment.
obs, infos = self.env.reset()
# Try resetting the environment.
obs, infos = self._try_env_reset()
# Add initial observations and infos.
_episode.add_env_reset(observations=obs, infos=infos)

@@ -786,13 +805,8 @@ def assert_healthy(self):
# Make sure, we have built our gym.vector.Env and RLModule properly.
assert self.env and self.module

@override(EnvRunner)
def make_env(self):
"""Creates a MultiAgentEnv (is-a gymnasium env).
Note that users can change the EnvRunner's config (e.g. change
`self.config.env_config`) and then call this method to create new environments
with the updated configuration.
"""
# If an env already exists, try closing it first (to allow it to properly
# cleanup).
if self.env is not None:
@@ -803,6 +817,7 @@ def make_env(self):
"Tried closing the existing env (multi-agent), but failed with "
f"error: {e.args[0]}"
)
del self.env

env_ctx = self.config.env_config
if not isinstance(env_ctx, EnvContext):
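The `make_env()` contract relied on above (close the previous env if one exists, then rebuild from `self.config.env_config`) is what lets `_try_env_reset()` and `_try_env_step()` recover by simply re-creating the env. A stripped-down sketch of that contract outside RLlib; `SimpleGymRunner` and the use of plain `gym.make()` are assumptions for illustration, not code from this diff:

import logging
import gymnasium as gym

logger = logging.getLogger("ray.rllib")

class SimpleGymRunner:
    """Illustrative only: the close-then-recreate contract of make_env()."""

    def __init__(self, env_id, env_config=None):
        self.env_id = env_id
        self.env_config = env_config or {}
        self.env = None
        self.make_env()

    def make_env(self):
        # If an env already exists, try closing it first so it can clean up.
        if self.env is not None:
            try:
                self.env.close()
            except Exception as e:
                logger.warning(f"Tried closing the existing env, but failed: {e}")
            del self.env
        # Build a fresh env from the (possibly updated) config.
        self.env = gym.make(self.env_id, **self.env_config)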
37 changes: 24 additions & 13 deletions rllib/env/single_agent_env_runner.py
@@ -19,7 +19,7 @@
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.utils.annotations import override
@@ -248,9 +248,9 @@ def _sample_timesteps(
# leak).
self._ongoing_episodes_for_metrics.clear()

# Reset the environment.
# Try resetting the environment.
# TODO (simon): Check, if we need here the seed from the config.
obs, infos = self.env.reset()
obs, infos = self._try_env_reset()
obs = unbatch(obs)
self._cached_to_module = None

@@ -317,10 +317,16 @@ def _sample_timesteps(
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS)
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
# Step the environment.
obs, rewards, terminateds, truncateds, infos = self.env.step(
actions_for_env
)
# Try stepping the environment.
results = self._try_env_step(actions_for_env)
if results == ENV_STEP_FAILURE:
return self._sample_timesteps(
num_timesteps=num_timesteps,
explore=explore,
random_actions=random_actions,
force_reset=True,
)
obs, rewards, terminateds, truncateds, infos = results
obs, actions = unbatch(obs), unbatch(actions)

ts += self.num_envs
@@ -457,9 +463,9 @@ def _sample_episodes(
# `gymnasium-v1.0.0a2` PR is coming.
_shared_data = {}

# Reset the environment.
# Try resetting the environment.
# TODO (simon): Check, if we need here the seed from the config.
obs, infos = self.env.reset()
obs, infos = self._try_env_reset()
for env_index in range(self.num_envs):
episodes[env_index].add_env_reset(
observation=unbatch(obs)[env_index],
@@ -514,10 +520,15 @@ def _sample_episodes(
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS)
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
# Step the environment.
obs, rewards, terminateds, truncateds, infos = self.env.step(
actions_for_env
)
# Try stepping the environment.
results = self._try_env_step(actions_for_env)
if results == ENV_STEP_FAILURE:
return self._sample_episodes(
num_episodes=num_episodes,
explore=explore,
random_actions=random_actions,
)
obs, rewards, terminateds, truncateds, infos = results
obs, actions = unbatch(obs), unbatch(actions)
ts += self.num_envs

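The `crashing_and_stalling_env.py` script exercised by the new BUILD targets is not included in this page's diff; the kind of environment it drives through these code paths can be approximated by a CartPole wrapper that randomly raises from `step()`, which is exactly what makes `_try_env_step()` return `ENV_STEP_FAILURE`. The class name and crash probability below are assumptions for illustration:

import random
import gymnasium as gym

class RandomlyCrashingCartPole(gym.Wrapper):
    """Illustrative stand-in for the crashing env used by the example tests."""

    def __init__(self, crash_prob=0.001):
        super().__init__(gym.make("CartPole-v1"))
        self.crash_prob = crash_prob

    def step(self, action):
        # Simulate a flaky simulator or backend: occasionally blow up mid-episode.
        if random.random() < self.crash_prob:
            raise RuntimeError("Simulated env crash during step().")
        return self.env.step(action)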
7 changes: 7 additions & 0 deletions rllib/evaluation/rollout_worker.py
@@ -625,6 +625,13 @@ def wrap(env):
# not tracking weights versions.
self.weights_seq_no: Optional[int] = None

@override(EnvRunner)
def make_env(self):
# Override this method, b/c it's abstract and must be overridden.
# However, we see no point in implementing it for the old API stack any longer
# (the RolloutWorker class will be deprecated soon).
raise NotImplementedError

@override(EnvRunner)
def assert_healthy(self):
is_healthy = self.policy_map and self.input_reader and self.output_writer