[RLlib] Fix large batch size for synchronous algos after EnvRunner failures. #47356

Merged
15 changes: 12 additions & 3 deletions rllib/execution/rollout_ops.py
@@ -140,6 +140,8 @@ def synchronous_parallel_sample(
agent_or_env_steps += sum(
int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
)
sample_batches_or_episodes.extend(sampled_data)
all_stats_dicts.extend(stats_dicts)
else:
for batch_or_episode in sampled_data:
if max_agent_steps:
@@ -154,9 +156,16 @@
if _uses_new_env_runners
else batch_or_episode.env_steps()
)
sample_batches_or_episodes.extend(sampled_data)
if _return_metrics:
all_stats_dicts.extend(stats_dicts)
sample_batches_or_episodes.append(batch_or_episode)
# Break out (and ignore the remaining samples) if max timesteps (batch
# size) reached. We want to avoid collecting batches that are too large
# only because of a failed/restarted worker causing a second iteration
# of the main loop.
if (
max_agent_or_env_steps is not None
and agent_or_env_steps >= max_agent_or_env_steps
):
break

if concat is True:
# If we have episodes flatten the episode list.
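
The core of the change above is the early break out of the sample-accumulation loop: once the collected agent/env steps reach max_agent_or_env_steps, any surplus samples (for example from a restarted EnvRunner that triggered a second pass of the main loop) are ignored instead of inflating the train batch. Below is a minimal, self-contained sketch of that pattern; accumulate_until_target, target_steps, and the batch dicts are illustrative stand-ins, not actual RLlib APIs.

    # Minimal sketch of the early-break pattern from synchronous_parallel_sample.
    # All names here (accumulate_until_target, target_steps) are hypothetical.
    from typing import Dict, List


    def accumulate_until_target(
        batches: List[Dict[str, int]], target_steps: int
    ) -> List[Dict[str, int]]:
        """Keep appending batches until the step target is reached, then stop."""
        kept: List[Dict[str, int]] = []
        steps = 0
        for batch in batches:
            kept.append(batch)
            steps += batch["env_steps"]
            # Break out (and ignore the remaining samples) once the target batch
            # size is reached, so extra samples from a restarted worker do not
            # double the effective train batch.
            if steps >= target_steps:
                break
        return kept


    if __name__ == "__main__":
        # Six workers each return 128 env steps; with a 300-step target only the
        # first three batches are kept.
        sampled = [{"env_steps": 128} for _ in range(6)]
        assert len(accumulate_until_target(sampled, target_steps=300)) == 3
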
121 changes: 64 additions & 57 deletions rllib/tests/test_node_failure.py
@@ -1,20 +1,22 @@
# This workload tests RLlib's ability to recover from failing worker nodes
import time
import unittest

import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.util.state import list_actors
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
LEARNER_RESULTS,
)


num_redis_shards = 5
redis_max_memory = 10**8
object_store_memory = 10**8
num_nodes = 3


assert (
num_nodes * object_store_memory + num_redis_shards * redis_max_memory
< ray._private.utils.get_system_memory() / 2
@@ -24,7 +26,7 @@
)


class NodeFailureTests(unittest.TestCase):
class TestNodeFailures(unittest.TestCase):
def setUp(self):
# Simulate a cluster on one machine.
self.cluster = Cluster()
@@ -46,69 +48,74 @@ def tearDown(self):
ray.shutdown()
self.cluster.shutdown()

def test_continue_training_on_failure(self):
# We tolerate failing workers and pause training
def test_continue_training_on_env_runner_node_failures(self):
# We tolerate failing workers and pause training.
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.env_runners(
num_env_runners=6,
validate_env_runners_after_construction=True,
)
.fault_tolerance(recreate_failed_env_runners=True)
.training(
train_batch_size=300,
.fault_tolerance(
ignore_env_runner_failures=True,
recreate_failed_env_runners=True,
)
)
ppo = PPO(config=config)

# One step with all nodes up, enough to satisfy resource requirements
ppo.train()

self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)

# Remove the first non-head node.
node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
self.cluster.remove_node(node_to_kill)
algo = config.build()

# step() should continue with 4 rollout workers.
ppo.train()
best_return = 0.0
for i in range(40):
results = algo.train()
print(f"ITER={i} results={results}")

self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 4)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)

# node comes back immediately.
self.cluster.add_node(
redis_port=None,
num_redis_shards=None,
num_cpus=2,
num_gpus=0,
object_store_memory=object_store_memory,
redis_max_memory=redis_max_memory,
dashboard_host="0.0.0.0",
)

# Now, let's wait for Ray to restart all the RolloutWorker actors.
while True:
states = [
a["state"] == "ALIVE"
for a in list_actors()
if a["class_name"] == "RolloutWorker"
best_return = max(
best_return, results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
)
avg_batch = results[LEARNER_RESULTS][DEFAULT_MODULE_ID][
"module_train_batch_size_mean"
]
if all(states):
break
# Otherwise, wait a bit.
time.sleep(1)

# This step should continue with 4 workers, but by the end
# of weight syncing, the 2 recovered rollout workers should
# be back.
ppo.train()

# Workers should be back up, everything back to normal.
self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)
self.assertGreaterEqual(avg_batch, config.total_train_batch_size)
self.assertLess(
avg_batch,
config.total_train_batch_size + config.get_rollout_fragment_length(),
)

self.assertEqual(algo.env_runner_group.num_remote_workers(), 6)
healthy_env_runners = algo.env_runner_group.num_healthy_remote_workers()
# After node has been removed, we expect 2 workers to be gone.
if (i - 1) % 5 == 0:
self.assertEqual(healthy_env_runners, 4)
# Otherwise, all workers should be there (but might still be in the process
# of coming up).
else:
self.assertIn(healthy_env_runners, [4, 5, 6])

# print(f"healthy workers = {algo.env_runner_group.healthy_worker_ids()}")
# Shut down one node every n iterations.
if i % 5 == 0:
to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
print(f"Killing node {to_kill} ...")
self.cluster.remove_node(to_kill)

# Bring back a previously failed node.
elif (i - 1) % 5 == 0:
print("Bringing back node ...")
self.cluster.add_node(
redis_port=None,
num_redis_shards=None,
num_cpus=2,
num_gpus=0,
object_store_memory=object_store_memory,
redis_max_memory=redis_max_memory,
dashboard_host="0.0.0.0",
)

self.assertGreaterEqual(best_return, 450.0)


if __name__ == "__main__":
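
The rewritten test trains PPO for 40 iterations on a simulated 3-node cluster, removes a non-head node every 5th iteration, brings it back on the next, and asserts that the averaged train batch (module_train_batch_size_mean) stays within one rollout fragment of the configured batch size. A small standalone check of that bound is sketched below; the concrete numbers are illustrative assumptions, not values taken from the test run.

    # Sketch of the batch-size bound asserted via module_train_batch_size_mean:
    # after an EnvRunner restart, the averaged train batch should not exceed the
    # configured size by more than one rollout fragment. Numbers are assumptions.
    TOTAL_TRAIN_BATCH_SIZE = 4000          # assumed PPO train batch size
    ROLLOUT_FRAGMENT_LENGTH = 4000 // 6    # assumed fragment with 6 env runners


    def batch_size_within_bounds(avg_batch: float) -> bool:
        """True if the averaged train batch respects the expected bounds."""
        return (
            TOTAL_TRAIN_BATCH_SIZE
            <= avg_batch
            < TOTAL_TRAIN_BATCH_SIZE + ROLLOUT_FRAGMENT_LENGTH
        )


    if __name__ == "__main__":
        assert batch_size_within_bounds(4100.0)
        # A batch doubled by a failed/restarted worker (the pre-fix behavior)
        # violates the upper bound.
        assert not batch_size_within_bounds(8000.0)
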