From 4783f3362210b005055084c51691fb42286f97a9 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 11 Dec 2024 17:33:41 +0100 Subject: [PATCH] [RLlib] Cleanup examples folder (new API stack) vol 31: Add hierarchical training example script. (#49127) --- doc/source/rllib/rllib-examples.rst | 10 + rllib/BUILD | 17 +- rllib/examples/envs/classes/six_room_env.py | 315 ++++++++++++++++++ .../hierarchical/hierarchical_training.py | 272 ++++++++------- rllib/utils/test_utils.py | 10 +- 5 files changed, 485 insertions(+), 139 deletions(-) create mode 100644 rllib/examples/envs/classes/six_room_env.py diff --git a/doc/source/rllib/rllib-examples.rst b/doc/source/rllib/rllib-examples.rst index c5f58ea523828..878ce0e709c67 100644 --- a/doc/source/rllib/rllib-examples.rst +++ b/doc/source/rllib/rllib-examples.rst @@ -186,6 +186,16 @@ GPU (for Training and Sampling) with performance improvements during evaluation. +Hierarchical Training ++++++++++++++++++++++ + +- `Hierarchical RL Training `__: + Showcases a hierarchical RL setup inspired by automatic subgoal discovery and subpolicy specialization. A high-level policy selects subgoals and assigns one of three + specialized low-level policies to achieve them within a time limit, encouraging specialization and efficient task-solving. + The agent has to navigate a complex grid-world environment. The example highlights the advantages of hierarchical + learning over flat approaches by demonstrating significantly improved learning performance in challenging, goal-oriented tasks. + + Inference (of Models/Policies) ++++++++++++++++++++++++++++++ diff --git a/rllib/BUILD b/rllib/BUILD index 1592a8bb4222f..ef46f2177719a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2600,24 +2600,15 @@ py_test( # subdirectory: hierarchical/ # .................................... -#@OldAPIStack -py_test( - name = "examples/hierarchical/hierarchical_training_tf", - main = "examples/hierarchical/hierarchical_training.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/hierarchical/hierarchical_training.py"], - args = [ "--framework=tf", "--stop-reward=0.0"] -) - -#@OldAPIStack +# TODO (sven): Add this script to the release tests as well. The problem is too hard to be solved +# in < 10min on a few CPUs. py_test( - name = "examples/hierarchical/hierarchical_training_torch", + name = "examples/hierarchical/hierarchical_training", main = "examples/hierarchical/hierarchical_training.py", tags = ["team:rllib", "exclusive", "examples"], size = "medium", srcs = ["examples/hierarchical/hierarchical_training.py"], - args = ["--framework=torch", "--stop-reward=0.0"] + args = ["--enable-new-api-stack", "--stop-iters=5", "--map=small", "--time-limit=100", "--max-steps-low-level=15"] ) # subdirectory: inference/ diff --git a/rllib/examples/envs/classes/six_room_env.py b/rllib/examples/envs/classes/six_room_env.py new file mode 100644 index 0000000000000..2a4b1a2a41d51 --- /dev/null +++ b/rllib/examples/envs/classes/six_room_env.py @@ -0,0 +1,315 @@ +import gymnasium as gym + +from ray.rllib.env.multi_agent_env import MultiAgentEnv + + +# Map representation: Always six rooms (as the name suggests) with doors in between. 
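+# Legend (grounded in the env code below): "W" = wall (impassable), "G" = goal state,
+# " " = walkable cell. The agent always starts at (row=1, col=1), i.e. inside the
+# upper-left room.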
+MAPS = { + "small": [ + "WWWWWWWWWWWWW", + "W W W W", + "W W W", + "W W W", + "W WWWW WWWW W", + "W W W W", + "W W W", + "W W GW", + "WWWWWWWWWWWWW", + ], + "medium": [ + "WWWWWWWWWWWWWWWWWWW", + "W W W W", + "W W W", + "W W W", + "W WWWWWWW WWWWWWW W", + "W W W W", + "W W W", + "W W GW", + "WWWWWWWWWWWWWWWWWWW", + ], + "large": [ + "WWWWWWWWWWWWWWWWWWWWWWWWW", + "W W W W", + "W W W W", + "W W W", + "W W W", + "W W W W", + "WW WWWWWWWWW WWWWWWWWWW W", + "W W W W", + "W W W", + "W W W W", + "W W W", + "W W W GW", + "WWWWWWWWWWWWWWWWWWWWWWWWW", + ], +} + + +class SixRoomEnv(gym.Env): + """A grid-world with six rooms (arranged as 2x3), which are connected by doors. + + The agent starts in the upper left room and has to reach a designated goal state + in one of the rooms using primitive actions up, left, down, and right. + + The agent receives a small penalty of -0.01 on each step and a reward of +10.0 when + reaching the goal state. + """ + + def __init__(self, config=None): + super().__init__() + + # User can provide a custom map or a recognized map name (small, medium, large). + self.map = config.get("custom_map", MAPS.get(config.get("map"), MAPS["small"])) + self.time_limit = config.get("time_limit", 50) + + # Define observation space: Discrete, index fields. + self.observation_space = gym.spaces.Discrete(len(self.map) * len(self.map[0])) + # Primitive actions: up, down, left, right. + self.action_space = gym.spaces.Discrete(4) + + # Initialize environment state. + self.reset() + + def reset(self, *, seed=None, options=None): + self._agent_pos = (1, 1) + self._ts = 0 + # Return high-level observation. + return self._agent_discrete_pos, {} + + def step(self, action): + next_pos = _get_next_pos(action, self._agent_pos) + + self._ts += 1 + + # Check if the move ends up in a wall. If so -> Ignore the move and stay + # where we are right now. + if self.map[next_pos[0]][next_pos[1]] != "W": + self._agent_pos = next_pos + + # Check if the agent has reached the global goal state. + if self.map[self._agent_pos[0]][self._agent_pos[1]] == "G": + return self._agent_discrete_pos, 10.0, True, False, {} + + # Small step penalty. + return self._agent_discrete_pos, -0.01, False, self._ts >= self.time_limit, {} + + @property + def _agent_discrete_pos(self): + x = self._agent_pos[0] + y = self._agent_pos[1] + # discrete position = row idx * columns + col idx + return x * len(self.map[0]) + y + + +class HierarchicalSixRoomEnv(MultiAgentEnv): + def __init__(self, config=None): + super().__init__() + + # User can provide a custom map or a recognized map name (small, medium, large). + self.map = config.get("custom_map", MAPS.get(config.get("map"), MAPS["small"])) + self.max_steps_low_level = config.get("max_steps_low_level", 15) + self.time_limit = config.get("time_limit", 50) + self.num_low_level_agents = config.get("num_low_level_agents", 3) + + self.agents = self.possible_agents = ["high_level_agent"] + [ + f"low_level_agent_{i}" for i in range(self.num_low_level_agents) + ] + + # Define basic observation space: Discrete, index fields. + observation_space = gym.spaces.Discrete(len(self.map) * len(self.map[0])) + # Low level agents always see where they are right now and what the target + # state should be. + low_level_observation_space = gym.spaces.Tuple( + (observation_space, observation_space) + ) + # Primitive actions: up, down, left, right. 
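+        # (Encoded as 0=up, 1=down, 2=left, 3=right; see `_get_next_pos` below.)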
+ low_level_action_space = gym.spaces.Discrete(4) + + self.observation_spaces = {"high_level_agent": observation_space} + self.observation_spaces.update( + { + f"low_level_agent_{i}": low_level_observation_space + for i in range(self.num_low_level_agents) + } + ) + self.action_spaces = { + "high_level_agent": gym.spaces.Tuple( + ( + # The new target observation. + observation_space, + # Low-level policy that should get us to the new target observation. + gym.spaces.Discrete(self.num_low_level_agents), + ) + ) + } + self.action_spaces.update( + { + f"low_level_agent_{i}": low_level_action_space + for i in range(self.num_low_level_agents) + } + ) + + # Initialize environment state. + self.reset() + + def reset(self, *, seed=None, options=None): + self._agent_pos = (1, 1) + self._low_level_steps = 0 + self._high_level_action = None + # Number of times the low-level agent reached the given target (by the high + # level agent). + self._num_targets_reached = 0 + + self._ts = 0 + + # Return high-level observation. + return { + "high_level_agent": self._agent_discrete_pos, + }, {} + + def step(self, action_dict): + self._ts += 1 + + terminateds = {"__all__": self._ts >= self.time_limit} + truncateds = {"__all__": False} + + # High-level agent acted: Set next goal and next low-level policy to use. + # Note that the agent does not move in this case and stays at its current + # location. + if "high_level_agent" in action_dict: + self._high_level_action = action_dict["high_level_agent"] + low_level_agent = f"low_level_agent_{self._high_level_action[1]}" + self._low_level_steps = 0 + # Return next low-level observation for the now-active agent. + # We want this agent to act next. + return ( + { + low_level_agent: ( + self._agent_discrete_pos, # current + self._high_level_action[0], # target + ) + }, + # Penalty for a target state that's close to the current state. + { + "high_level_agent": ( + self.eucl_dist( + self._agent_discrete_pos, + self._high_level_action[0], + self.map, + ) + / (len(self.map) ** 2 + len(self.map[0]) ** 2) ** 0.5 + ) + - 1.0, + }, + terminateds, + truncateds, + {}, + ) + # Low-level agent made a move (primitive action). + else: + assert len(action_dict) == 1 + + # Increment low-level step counter. + self._low_level_steps += 1 + + target_discrete_pos, low_level_agent = self._high_level_action + low_level_agent = f"low_level_agent_{low_level_agent}" + next_pos = _get_next_pos(action_dict[low_level_agent], self._agent_pos) + + # Check if the move ends up in a wall. If so -> Ignore the move and stay + # where we are right now. + if self.map[next_pos[0]][next_pos[1]] != "W": + self._agent_pos = next_pos + + # Check if the agent has reached the global goal state. + if self.map[self._agent_pos[0]][self._agent_pos[1]] == "G": + rewards = { + "high_level_agent": 10.0, + # +1.0 if the goal position was also the target position for the + # low level agent. + low_level_agent: float( + self._agent_discrete_pos == target_discrete_pos + ), + } + terminateds["__all__"] = True + return ( + {"high_level_agent": self._agent_discrete_pos}, + rewards, + terminateds, + truncateds, + {}, + ) + + # Low-level agent has reached its target location (given by the high-level): + # - Hand back control to high-level agent. + # - Reward low level agent and high-level agent with small rewards. 
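+            #   (+1.0 each here, vs. +10.0 for the high-level agent when the global
+            #   goal state "G" is reached.)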
+            elif self._agent_discrete_pos == target_discrete_pos:
+                self._num_targets_reached += 1
+                rewards = {
+                    "high_level_agent": 1.0,
+                    low_level_agent: 1.0,
+                }
+                return (
+                    {"high_level_agent": self._agent_discrete_pos},
+                    rewards,
+                    terminateds,
+                    truncateds,
+                    {},
+                )
+
+            # Low-level agent has not reached anything.
+            else:
+                # Small step penalty for low-level agent.
+                rewards = {low_level_agent: -0.01}
+                # Reached time budget -> Hand back control to high level agent.
+                if self._low_level_steps >= self.max_steps_low_level:
+                    rewards["high_level_agent"] = -0.01
+                    return (
+                        {"high_level_agent": self._agent_discrete_pos},
+                        rewards,
+                        terminateds,
+                        truncateds,
+                        {},
+                    )
+                else:
+                    return (
+                        {
+                            low_level_agent: (
+                                self._agent_discrete_pos,  # current
+                                target_discrete_pos,  # target
+                            ),
+                        },
+                        rewards,
+                        terminateds,
+                        truncateds,
+                        {},
+                    )
+
+    @property
+    def _agent_discrete_pos(self):
+        x = self._agent_pos[0]
+        y = self._agent_pos[1]
+        # discrete position = row idx * columns + col idx
+        return x * len(self.map[0]) + y
+
+    @staticmethod
+    def eucl_dist(pos1, pos2, map):
+        # Convert discrete positions back to (row, col) coordinates. Note that the
+        # number of columns (len(map[0])) must be used for both conversions.
+        row1, col1 = pos1 // len(map[0]), pos1 % len(map[0])
+        row2, col2 = pos2 // len(map[0]), pos2 % len(map[0])
+        return ((row1 - row2) ** 2 + (col1 - col2) ** 2) ** 0.5
+
+
+def _get_next_pos(action, pos):
+    x, y = pos
+    # Up.
+    if action == 0:
+        return x - 1, y
+    # Down.
+    elif action == 1:
+        return x + 1, y
+    # Left.
+    elif action == 2:
+        return x, y - 1
+    # Right.
+    else:
+        return x, y + 1
diff --git a/rllib/examples/hierarchical/hierarchical_training.py b/rllib/examples/hierarchical/hierarchical_training.py
index ccdc067fe3ae0..d7401e262ea61 100644
--- a/rllib/examples/hierarchical/hierarchical_training.py
+++ b/rllib/examples/hierarchical/hierarchical_training.py
@@ -1,152 +1,182 @@
-# @OldAPIStack
-
-"""Example of hierarchical training using the multi-agent API.
-
-The example env is that of a "windy maze". The agent observes the current wind
-direction and can either choose to stand still, or move in that direction.
-
-You can try out the env directly with:
-
-    $ python hierarchical_training.py --flat
-
-A simple hierarchical formulation involves a high-level agent that issues goals
-(i.e., go north / south / east / west), and a low-level agent that executes
-these goals over a number of time-steps. This can be implemented as a
-multi-agent environment with a top-level agent and low-level agents spawned
-for each higher-level action. The lower level agent is rewarded for moving
-in the right direction.
-
-You can try this formulation with:
-
-    $ python hierarchical_training.py  # gets ~100 rew after ~100k timesteps
-
-Note that the hierarchical formulation actually converges slightly slower than
-using --flat in this example.
+"""Example of running a hierarchical training setup in RLlib using its multi-agent API.
+
+This example is very loosely based on this paper:
+[1] Hierarchical RL Based on Subgoal Discovery and Subpolicy Specialization -
+B. Bakker & J. Schmidhuber - 2003
+
+The approach features one high-level policy, which picks both the next target state to
+be reached and the low-level policy (one of three by default) that should take over
+control to reach it.
+A low-level policy - once chosen by the high-level one - has up to
+`--max-steps-low-level` primitive timesteps (15 by default) to reach the given target
+state. If it reaches the target, both the high-level and the low-level policy are
+rewarded and the high-level policy takes another action (chooses a new target state
+and a new low-level policy).
+A global goal state (the "G" cell in the map) must be reached to solve the overall
+task.
+Once one of the lower-level policies reaches that goal state, the high-level policy
+receives a large reward and the episode ends.
+The approach exploits the fact that the low-level policies can specialize in reaching
+certain subgoals, while the high-level policy learns which subgoal to pick next and
+which "expert" (low-level policy) to task with reaching it.
+
+This example:
+    - demonstrates how to write a relatively simple custom multi-agent environment
+    that mimics a hierarchical RL setup with higher- and lower-level agents acting on
+    different abstract time axes (the higher-level policy only acts occasionally,
+    picking a new target state and a new lower-level policy; the chosen lower-level
+    policy then has a limited number of primitive timesteps (`--max-steps-low-level`)
+    to reach that target, after which control is handed back to the higher-level
+    policy for the next pick).
+    - shows how to set up a plain multi-agent RL algorithm (here: PPO) to learn in this
+    hierarchical setup and solve tasks that are otherwise very difficult to solve
+    with only a single low-level policy that picks primitive actions.
+
+We use the `SixRoomEnv` and `HierarchicalSixRoomEnv`, both sharing the same built-in
+maps. The envs are similar to the FrozenLake-v1 env, but support walls (inner and
+outer) through which the agent cannot walk.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --map=large --time-limit=50`

+Use the `--flat` option to disable the hierarchical setup and learn the simple (flat)
+SixRoomEnv with only one policy. You should observe that it's much harder for the algo
+to reach the global goal state in this setting.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+
+Results to expect
+-----------------
+In the console output, you can see that only a PPO algorithm that uses hierarchical
+training (`--flat` flag is NOT set) can actually learn with the command line options
+`--map=large --time-limit=500 --max-steps-low-level=40 --num-low-level-agents=3`.
+
+4 policies in a hierarchical setup (1 high-level "manager", 3 low-level "experts"):
++---------------------+----------+--------+------------------+
+| Trial name          | status   |   iter |   total time (s) |
+|                     |          |        |                  |
+|---------------------+----------+--------+------------------+
+| PPO_env_58b78_00000 | RUNNING  |    100 |           278.23 |
++---------------------+----------+--------+------------------+
++-------------------+--------------------------+---------------------------+ ...
+| combined return   | return high_level_policy | return low_level_policy_0 |
+|-------------------+--------------------------+---------------------------+ ...
+|              -8.4 |                     -5.2 |                     -1.19 |
++-------------------+--------------------------+---------------------------+ ...
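+
+Note that exact returns and timings vary from run to run. For a quick, much cheaper
+smoke test (the configuration used in RLlib's BUILD test for this script), run:
+`python [script file name].py --enable-new-api-stack --stop-iters=5 --map=small
+--time-limit=100 --max-steps-low-level=15`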
""" - -import argparse -from gymnasium.spaces import Discrete, Tuple -import logging -import os - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION +from ray import tune from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.examples.envs.classes.windy_maze_env import ( - WindyMazeEnv, - HierarchicalWindyMazeEnv, +from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations +from ray.rllib.examples.envs.classes.six_room_env import ( + HierarchicalSixRoomEnv, + SixRoomEnv, ) -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, ) -from ray.rllib.utils.test_utils import check_learning_achieved -parser = argparse.ArgumentParser() -parser.add_argument("--flat", action="store_true") -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", +parser = add_rllib_example_script_args( + default_reward=7.0, + default_timesteps=4000000, + default_iters=800, ) parser.add_argument( - "--as-test", + "--flat", action="store_true", - help="Whether this script should be run as a test: --stop-reward must " - "be achieved within --stop-timesteps AND --stop-iters.", + help="Use the non-hierarchical, single-agent flat `SixRoomEnv` instead.", ) parser.add_argument( - "--stop-iters", type=int, default=200, help="Number of iterations to train." + "--map", + type=str, + choices=["small", "medium", "large"], + default="medium", + help="The built-in map to use.", ) parser.add_argument( - "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train." + "--time-limit", + type=int, + default=100, + help="The max. number of (primitive) timesteps per episode.", ) parser.add_argument( - "--stop-reward", type=float, default=0.0, help="Reward at which we stop training." + "--max-steps-low-level", + type=int, + default=15, + help="The max. number of steps a low-level policy can take after having been " + "picked by the high level policy. After this number of timesteps, control is " + "handed back to the high-level policy (to pick a next goal position plus the next " + "low level policy).", ) parser.add_argument( - "--local-mode", - action="store_true", - help="Init Ray in local mode for easier debugging.", + "--num-low-level-agents", + type=int, + default=3, + help="The number of low-level agents/policies to use.", ) +parser.set_defaults(enable_new_api_stack=True) -logger = logging.getLogger(__name__) if __name__ == "__main__": args = parser.parse_args() - ray.init(local_mode=args.local_mode) - - stop = { - TRAINING_ITERATION: args.stop_iters, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - } + # Run the flat (non-hierarchical env). if args.flat: - results = tune.Tuner( - "PPO", - run_config=air.RunConfig(stop=stop), - param_space=( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment(WindyMazeEnv) - .env_runners(num_env_runners=0) - .framework(args.framework) - ).to_dict(), - ).fit() + cls = SixRoomEnv + # Run in hierarchical mode. 
else: - maze = WindyMazeEnv(None) + cls = HierarchicalSixRoomEnv + + tune.register_env("env", lambda cfg: cls(config=cfg)) + + base_config = ( + PPOConfig() + .environment( + "env", + env_config={ + "map": args.map, + "max_steps_low_level": args.max_steps_low_level, + "time_limit": args.time_limit, + "num_low_level_agents": args.num_low_level_agents, + }, + ) + .env_runners( + # num_envs_per_env_runner=10, + env_to_module_connector=( + lambda env: FlattenObservations(multi_agent=not args.flat) + ), + ) + .training( + train_batch_size_per_learner=4000, + minibatch_size=512, + lr=0.0003, + num_epochs=20, + entropy_coeff=0.025, + ) + ) + + # Configure a proper multi-agent setup for the hierarchical env. + if not args.flat: - def policy_mapping_fn(agent_id, episode, worker, **kwargs): + def policy_mapping_fn(agent_id, episode, **kwargs): + # Map each low level agent to its respective (low-level) policy. if agent_id.startswith("low_level_"): - return "low_level_policy" + return f"low_level_policy_{agent_id[-1]}" + # Map the high level agent to the high level policy. else: return "high_level_policy" - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment(HierarchicalWindyMazeEnv) - .framework(args.framework) - .env_runners(num_env_runners=0) - .training(entropy_coeff=0.01) - .multi_agent( - policies={ - "high_level_policy": ( - None, - maze.observation_space, - Discrete(4), - PPOConfig.overrides(gamma=0.9), - ), - "low_level_policy": ( - None, - Tuple([maze.observation_space, Discrete(4)]), - maze.action_space, - PPOConfig.overrides(gamma=0.0), - ), - }, - policy_mapping_fn=policy_mapping_fn, - ) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) + base_config.multi_agent( + policy_mapping_fn=policy_mapping_fn, + policies={"high_level_policy"} + | {f"low_level_policy_{i}" for i in range(args.num_low_level_agents)}, ) - results = tune.Tuner( - "PPO", - param_space=config.to_dict(), - run_config=air.RunConfig(stop=stop, verbose=1), - ).fit() - - if args.as_test: - check_learning_achieved(results, args.stop_reward) - - ray.shutdown() + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 0973188d88487..0be84e9cf23eb 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -25,7 +25,7 @@ import tree # pip install dm_tree import ray -from ray import air, tune +from ray import train, tune from ray.air.constants import TRAINING_ITERATION from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR from ray.rllib.core import DEFAULT_MODULE_ID, Columns @@ -1249,11 +1249,11 @@ def run_rllib_example_script_experiment( results = tune.Tuner( trainable or config.algo_class, param_space=config, - run_config=air.RunConfig( + run_config=train.RunConfig( stop=stop, verbose=args.verbose, callbacks=tune_callbacks, - checkpoint_config=air.CheckpointConfig( + checkpoint_config=train.CheckpointConfig( checkpoint_frequency=args.checkpoint_freq, checkpoint_at_end=args.checkpoint_at_end, ), @@ -1471,14 +1471,14 @@ def check_reproducibilty( results1 = tune.Tuner( algo_class, param_space=algo_config.to_dict(), - run_config=air.RunConfig(stop=stop_dict, verbose=1), + run_config=train.RunConfig(stop=stop_dict, verbose=1), ).fit() results1 = results1.get_best_result().metrics results2 = tune.Tuner( algo_class, param_space=algo_config.to_dict(), - 
run_config=air.RunConfig(stop=stop_dict, verbose=1), + run_config=train.RunConfig(stop=stop_dict, verbose=1), ).fit() results2 = results2.get_best_result().metrics