diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py
index b77b30cf3c41..255d0ba4ba71 100644
--- a/rllib/execution/rollout_ops.py
+++ b/rllib/execution/rollout_ops.py
@@ -140,6 +140,8 @@ def synchronous_parallel_sample(
                 agent_or_env_steps += sum(
                     int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
                 )
+            sample_batches_or_episodes.extend(sampled_data)
+            all_stats_dicts.extend(stats_dicts)
         else:
             for batch_or_episode in sampled_data:
                 if max_agent_steps:
@@ -154,9 +156,16 @@ def synchronous_parallel_sample(
                         if _uses_new_env_runners
                         else batch_or_episode.env_steps()
                     )
-        sample_batches_or_episodes.extend(sampled_data)
-        if _return_metrics:
-            all_stats_dicts.extend(stats_dicts)
+                sample_batches_or_episodes.append(batch_or_episode)
+                # Break out (and ignore the remaining samples) if max timesteps (batch
+                # size) reached. We want to avoid collecting batches that are too large
+                # only because of a failed/restarted worker causing a second iteration
+                # of the main loop.
+                if (
+                    max_agent_or_env_steps is not None
+                    and agent_or_env_steps >= max_agent_or_env_steps
+                ):
+                    break
 
     if concat is True:
         # If we have episodes flatten the episode list.
diff --git a/rllib/tests/test_node_failure.py b/rllib/tests/test_node_failure.py
index 383a9a95d6fa..b2d65be2699f 100644
--- a/rllib/tests/test_node_failure.py
+++ b/rllib/tests/test_node_failure.py
@@ -1,12 +1,15 @@
-# This workload tests RLlib's ability to recover from failing workers nodes
-import time
 import unittest
 
 import ray
 from ray._private.test_utils import get_other_nodes
 from ray.cluster_utils import Cluster
-from ray.util.state import list_actors
-from ray.rllib.algorithms.ppo import PPO, PPOConfig
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.utils.metrics import (
+    ENV_RUNNER_RESULTS,
+    EPISODE_RETURN_MEAN,
+    LEARNER_RESULTS,
+)
 
 
 num_redis_shards = 5
@@ -14,7 +17,6 @@
 object_store_memory = 10**8
 num_nodes = 3
 
-
 assert (
     num_nodes * object_store_memory + num_redis_shards * redis_max_memory
     < ray._private.utils.get_system_memory() / 2
@@ -24,7 +26,7 @@
 )
 
 
-class NodeFailureTests(unittest.TestCase):
+class TestNodeFailures(unittest.TestCase):
     def setUp(self):
         # Simulate a cluster on one machine.
         self.cluster = Cluster()
@@ -46,69 +48,74 @@ def tearDown(self):
         ray.shutdown()
         self.cluster.shutdown()
 
-    def test_continue_training_on_failure(self):
-        # We tolerate failing workers and pause training
+    def test_continue_training_on_env_runner_node_failures(self):
+        # We tolerate failing workers and pause training.
         config = (
             PPOConfig()
+            .api_stack(
+                enable_rl_module_and_learner=True,
+                enable_env_runner_and_connector_v2=True,
+            )
             .environment("CartPole-v1")
             .env_runners(
                 num_env_runners=6,
                 validate_env_runners_after_construction=True,
             )
-            .fault_tolerance(recreate_failed_env_runners=True)
-            .training(
-                train_batch_size=300,
+            .fault_tolerance(
+                ignore_env_runner_failures=True,
+                recreate_failed_env_runners=True,
             )
         )
-        ppo = PPO(config=config)
-
-        # One step with all nodes up, enough to satisfy resource requirements
-        ppo.train()
-
-        self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
-        self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)
-
-        # Remove the first non-head node.
-        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
-        self.cluster.remove_node(node_to_kill)
+        algo = config.build()
 
-        # step() should continue with 4 rollout workers.
-        ppo.train()
+        best_return = 0.0
+        for i in range(40):
+            results = algo.train()
+            print(f"ITER={i} results={results}")
 
-        self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 4)
-        self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)
-
-        # node comes back immediately.
-        self.cluster.add_node(
-            redis_port=None,
-            num_redis_shards=None,
-            num_cpus=2,
-            num_gpus=0,
-            object_store_memory=object_store_memory,
-            redis_max_memory=redis_max_memory,
-            dashboard_host="0.0.0.0",
-        )
-
-        # Now, let's wait for Ray to restart all the RolloutWorker actors.
-        while True:
-            states = [
-                a["state"] == "ALIVE"
-                for a in list_actors()
-                if a["class_name"] == "RolloutWorker"
+            best_return = max(
+                best_return, results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
+            )
+            avg_batch = results[LEARNER_RESULTS][DEFAULT_MODULE_ID][
+                "module_train_batch_size_mean"
             ]
-            if all(states):
-                break
-            # Otherwise, wait a bit.
-            time.sleep(1)
-
-        # This step should continue with 4 workers, but by the end
-        # of weight syncing, the 2 recovered rollout workers should
-        # be back.
-        ppo.train()
-
-        # Workers should be back up, everything back to normal.
-        self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
-        self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)
+            self.assertGreaterEqual(avg_batch, config.total_train_batch_size)
+            self.assertLess(
+                avg_batch,
+                config.total_train_batch_size + config.get_rollout_fragment_length(),
+            )
+
+            self.assertEqual(algo.env_runner_group.num_remote_workers(), 6)
+            healthy_env_runners = algo.env_runner_group.num_healthy_remote_workers()
+            # After node has been removed, we expect 2 workers to be gone.
+            if (i - 1) % 5 == 0:
+                self.assertEqual(healthy_env_runners, 4)
+            # Otherwise, all workers should be there (but might still be in the process
+            # of coming up).
+            else:
+                self.assertIn(healthy_env_runners, [4, 5, 6])
+
+            # print(f"healthy workers = {algo.env_runner_group.healthy_worker_ids()}")
+            # Shut down one node every n iterations.
+            if i % 5 == 0:
+                to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
+                print(f"Killing node {to_kill} ...")
+                self.cluster.remove_node(to_kill)
+
+            # Bring back a previously failed node.
+            elif (i - 1) % 5 == 0:
+                print("Bringing back node ...")
+                self.cluster.add_node(
+                    redis_port=None,
+                    num_redis_shards=None,
+                    num_cpus=2,
+                    num_gpus=0,
+                    object_store_memory=object_store_memory,
+                    redis_max_memory=redis_max_memory,
+                    dashboard_host="0.0.0.0",
+                )
+
+        self.assertGreaterEqual(best_return, 450.0)
 
 
 if __name__ == "__main__":