diff --git a/.travis.yml b/.travis.yml index b48089d52ecd..5f45138cbdda 100644 --- a/.travis.yml +++ b/.travis.yml @@ -131,6 +131,21 @@ script: # module is only found if the test directory is in the PYTHONPATH. - export PYTHONPATH="$PYTHONPATH:./test/" + # ray tune tests + - python python/ray/tune/test/dependency_test.py + - python -m pytest -v python/ray/tune/test/trial_runner_test.py + - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py + - python -m pytest -v python/ray/tune/test/experiment_test.py + - python -m pytest -v python/ray/tune/test/tune_server_test.py + - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py + - python -m pytest -v python/ray/tune/test/automl_searcher_test.py + + # ray rllib tests + - python -m pytest -v python/ray/rllib/test/test_catalog.py + - python -m pytest -v python/ray/rllib/test/test_filters.py + - python -m pytest -v python/ray/rllib/test/test_optimizers.py + - python -m pytest -v python/ray/rllib/test/test_evaluators.py + - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py - python -m pytest -v python/ray/test/test_ray_init.py @@ -153,21 +168,6 @@ script: - python -m pytest -v test/credis_test.py - python -m pytest -v test/node_manager_test.py - # ray tune tests - - python python/ray/tune/test/dependency_test.py - - python -m pytest -v python/ray/tune/test/trial_runner_test.py - - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py - - python -m pytest -v python/ray/tune/test/experiment_test.py - - python -m pytest -v python/ray/tune/test/tune_server_test.py - - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py - - python -m pytest -v python/ray/tune/test/automl_searcher_test.py - - # ray rllib tests - - python -m pytest -v python/ray/rllib/test/test_catalog.py - - python -m pytest -v python/ray/rllib/test/test_filters.py - - python -m pytest -v python/ray/rllib/test/test_optimizers.py - - python -m pytest -v python/ray/rllib/test/test_evaluators.py - # ray temp file tests - python -m pytest -v test/tempfile_test.py diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index ca36186e1a5f..37ea011a0b5c 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -24,27 +24,39 @@ ARS **Yes** **Yes** No No .. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces -In the high-level agent APIs, environments are identified with string names. By default, the string will be interpreted as a gym `environment name `__, however you can also register custom environments by name: +You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name `__. Custom env classes must take a single ``env_config`` parameter in their constructor: .. code-block:: python import ray - from ray.tune.registry import register_env from ray.rllib.agents import ppo - def env_creator(env_config): - import gym - return gym.make("CartPole-v0") # or return your own custom env + class MyEnv(gym.Env): + def __init__(self, env_config): + self.action_space = ... + self.observation_space = ... + ... - register_env("my_env", env_creator) ray.init() - trainer = ppo.PPOAgent(env="my_env", config={ - "env_config": {}, # config to pass to env creator + trainer = ppo.PPOAgent(env=MyEnv, config={ + "env_config": {}, # config to pass to env class }) while True: print(trainer.train()) +You can also register a custom env creator function with a string name. 
This function must take a single ``env_config`` parameter and return an env instance: + +.. code-block:: python + + from ray.tune.registry import register_env + + def env_creator(env_config): + return MyEnv(...) # return an env instance + + register_env("my_env", env_creator) + trainer = ppo.PPOAgent(env="my_env") + Configuring Environments ------------------------ diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index dc37d22943ba..e647b0a2791f 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -224,35 +224,6 @@ Sometimes, it is necessary to coordinate between pieces of code that live in dif Ray actors provide high levels of performance, so in more complex cases they can be used implement communication patterns such as parameter servers and allreduce. -Debugging ---------- - -Gym Monitor -~~~~~~~~~~~ - -The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example: - -.. code-block:: bash - - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 2, "monitor": true}' - - # videos will be saved in the ~/ray_results/ dir, for example - openaigym.video.0.31401.video000000.meta.json - openaigym.video.0.31401.video000000.mp4 - openaigym.video.0.31403.video000000.meta.json - openaigym.video.0.31403.video000000.mp4 - -Log Verbosity -~~~~~~~~~~~~~ - -You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: - -.. code-block:: bash - - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}' - Callbacks and Custom Metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -278,6 +249,10 @@ You can provide callback functions to be called at points during policy evaluati episode.episode_id, episode.length, mean_pole_angle)) episode.custom_metrics["mean_pole_angle"] = mean_pole_angle + def on_train_result(info): + print("agent.train() result: {} -> {} episodes".format( + info["agent"].__name__, info["result"]["episodes_this_iter"])) + ray.init() trials = tune.run_experiments({ "test": { @@ -288,6 +263,7 @@ You can provide callback functions to be called at points during policy evaluati "on_episode_start": tune.function(on_episode_start), "on_episode_step": tune.function(on_episode_step), "on_episode_end": tune.function(on_episode_end), + "on_train_result": tune.function(on_train_result), }, }, } @@ -297,6 +273,113 @@ Custom metrics can be accessed and visualized like any other training result: .. image:: custom_metric.png +Example: Curriculum Learning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let's look at two ways to use the above APIs to implement `curriculum learning `__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time: + +Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function: + +.. 
code-block:: python + + import ray + from ray import tune + from ray.rllib.agents.ppo import PPOAgent + + def train(config, reporter): + agent = PPOAgent(config=config, env=YourEnv) + while True: + result = agent.train() + reporter(**result) + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": train, + "config": { + "num_gpus": 0, + "num_workers": 2, + }, + "trial_resources": { + "cpu": 1, + "gpu": lambda spec: spec.config.num_gpus, + "extra_cpu": lambda spec: spec.config.num_workers, + }, + }, + }) + +Approach 2: Use the callbacks API to update the environment on new training results: + +.. code-block:: python + + import ray + from ray import tune + + def on_train_result(info): + result = info["result"] + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent = info["agent"] + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": "PPO", + "env": YourEnv, + "config": { + "callbacks": { + "on_train_result": tune.function(on_train_result), + }, + }, + }, + }) + +Debugging +--------- + +Gym Monitor +~~~~~~~~~~~ + +The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example: + +.. code-block:: bash + + python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "monitor": true}' + + # videos will be saved in the ~/ray_results/ dir, for example + openaigym.video.0.31401.video000000.meta.json + openaigym.video.0.31401.video000000.mp4 + openaigym.video.0.31403.video000000.meta.json + openaigym.video.0.31403.video000000.mp4 + +Log Verbosity +~~~~~~~~~~~~~ + +You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: + +.. code-block:: bash + + python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}' + +Stack Traces +~~~~~~~~~~~~ + +You can use the ``ray stack`` command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues. 
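+For example, on a node where a job appears to be hung, you might run the
+following to see where each worker process is currently blocked (a minimal
+invocation; only the bare ``ray stack`` command from above is assumed):
+
+.. code-block:: bash
+
+    # Dumps the stack traces of the Python worker processes on this node.
+    ray stack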
+ REST API -------- diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index f0d9510756b9..f25ff32ab5f8 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -2,12 +2,13 @@ from __future__ import division from __future__ import print_function +from datetime import datetime import copy -import os import logging +import os import pickle +import six import tempfile -from datetime import datetime import tensorflow as tf import ray @@ -15,7 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable from ray.tune.trial import Resources from ray.tune.logger import UnifiedLogger @@ -40,6 +41,7 @@ "on_episode_step": None, # arg: {"env": .., "episode": ...} "on_episode_end": None, # arg: {"env": .., "episode": ...} "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_train_result": None, # arg: {"agent": ..., "result": ...} }, # === Policy === @@ -274,7 +276,7 @@ def __init__(self, config=None, env=None, logger_creator=None): self.global_vars = {"timestep": 0} # Agents allow env ids to be passed directly to the constructor. - self._env_id = env or config.get("env") + self._env_id = _register_if_needed(env or config.get("env")) # Create a default logger creator if no logger_creator is specified if logger_creator is None: @@ -316,7 +318,13 @@ def train(self): logger.debug("synchronized filters: {}".format( self.local_evaluator.filters)) - return Trainable.train(self) + result = Trainable.train(self) + if self.config["callbacks"].get("on_train_result"): + self.config["callbacks"]["on_train_result"]({ + "agent": self, + "result": result, + }) + return result def _setup(self, config): env = self._env_id @@ -444,6 +452,15 @@ def _restore(self, checkpoint_path): self.__setstate__(extra_data) +def _register_if_needed(env_object): + if isinstance(env_object, six.string_types): + return env_object + elif isinstance(env_object, type): + name = env_object.__name__ + register_env(name, lambda config: env_object(config)) + return name + + def get_agent_class(alg): """Returns the class of a known agent given its name.""" diff --git a/python/ray/rllib/examples/carla/a3c_lane_keep.py b/python/ray/rllib/examples/carla/a3c_lane_keep.py deleted file mode 100644 index 9629808ba4c7..000000000000 --- a/python/ray/rllib/examples/carla/a3c_lane_keep.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-a3c": { - "run": "A3C", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - 
"custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "gamma": 0.8, - "num_workers": 1, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/dqn_lane_keep.py b/python/ray/rllib/examples/carla/dqn_lane_keep.py deleted file mode 100644 index 84fed98cd5f9..000000000000 --- a/python/ray/rllib/examples/carla/dqn_lane_keep.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": True, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-dqn": { - "run": "DQN", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "timesteps_per_iteration": 100, - "learning_starts": 1000, - "schedule_max_timesteps": 100000, - "gamma": 0.8, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/ppo_lane_keep.py b/python/ray/rllib/examples/carla/ppo_lane_keep.py deleted file mode 100644 index ac0f6ff8aff0..000000000000 --- a/python/ray/rllib/examples/carla/ppo_lane_keep.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-ppo": { - "run": "PPO", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "num_workers": 1, - "timesteps_per_batch": 2000, - "min_steps_per_task": 100, - "lambda": 0.95, - "clip_param": 0.2, - "num_sgd_iter": 20, - "sgd_stepsize": 0.0001, - "sgd_batchsize": 32, - "devices": ["/gpu:0"], - "tf_session_args": { - "gpu_options": { - "allow_growth": True - } - } - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py index 2c12cd8245cf..8fbcfbc576d1 100644 --- a/python/ray/rllib/examples/carla/train_a3c.py +++ b/python/ray/rllib/examples/carla/train_a3c.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import grid_search, 
register_env, run_experiments +from ray.tune import grid_search, run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -23,7 +22,6 @@ "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() redis_address = ray.services.get_node_ip_address() + ":6379" @@ -31,7 +29,7 @@ run_experiments({ "carla-a3c": { "run": "A3C", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "use_gpu_for_workers": True, diff --git a/python/ray/rllib/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py index fa2dba1053aa..27aa65444d38 100644 --- a/python/ray/rllib/examples/carla/train_dqn.py +++ b/python/ray/rllib/examples/carla/train_dqn.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_ONE_CURVE -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -21,7 +20,6 @@ "scenarios": TOWN2_ONE_CURVE, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init() @@ -35,7 +33,7 @@ def shape_out(spec): run_experiments({ "carla-dqn": { "run": "DQN", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "model": { diff --git a/python/ray/rllib/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py index a9339ca79481..6c49240142c2 100644 --- a/python/ray/rllib/examples/carla/train_ppo.py +++ b/python/ray/rllib/examples/carla/train_ppo.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -20,14 +19,13 @@ "server_map": "/Game/Maps/Town02", "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init(redirect_output=True) run_experiments({ "carla": { "run": "PPO", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "model": { diff --git a/python/ray/rllib/examples/custom_env.py b/python/ray/rllib/examples/custom_env.py index 66c0288081f9..0d96eef6acb6 100644 --- a/python/ray/rllib/examples/custom_env.py +++ b/python/ray/rllib/examples/custom_env.py @@ -11,7 +11,6 @@ import ray from ray.tune import run_experiments -from ray.tune.registry import register_env class SimpleCorridor(gym.Env): @@ -42,13 +41,13 @@ def step(self, action): if __name__ == "__main__": - env_creator_name = "corridor" - register_env(env_creator_name, lambda config: SimpleCorridor(config)) + # Can also register the env creator function explicitly with: + # register_env("corridor", lambda config: SimpleCorridor(config)) ray.init() run_experiments({ "demo": { "run": "PPO", - "env": "corridor", + "env": SimpleCorridor, # or "corridor" if registered above "config": { "env_config": { "corridor_length": 5, diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py index 
eec7bffb571f..c92ae8783748 100644 --- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -35,6 +35,13 @@ def on_sample_end(info): print("returned sample batch of size {}".format(info["samples"].count)) +def on_train_result(info): + print("agent.train() result: {} -> {} episodes".format( + info["agent"], info["result"]["episodes_this_iter"])) + # you can mutate the result dict to add new fields to return + info["result"]["callback_ok"] = True + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=2000) @@ -54,6 +61,7 @@ def on_sample_end(info): "on_episode_step": tune.function(on_episode_step), "on_episode_end": tune.function(on_episode_end), "on_sample_end": tune.function(on_sample_end), + "on_train_result": tune.function(on_train_result), }, }, } @@ -64,3 +72,4 @@ def on_sample_end(info): print(custom_metrics) assert "mean_pole_angle" in custom_metrics assert type(custom_metrics["mean_pole_angle"]) is float + assert "callback_ok" in trials[0].last_result diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 65683eeb53c7..2f7dd175b483 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -314,8 +314,10 @@ def __repr__(self): def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.""" if "env" in self.config: - identifier = "{}_{}".format(self.trainable_name, - self.config["env"]) + env = self.config["env"] + if isinstance(env, type): + env = env.__name__ + identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: