Skip to content

Commit

Permalink
[RLlib] New ConnectorV2 API #3: Introduce actual ConnectorV2 API. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Dec 21, 2023
1 parent e27ffa0 commit bd555a0
Show file tree
Hide file tree
Showing 36 changed files with 1,911 additions and 71 deletions.
17 changes: 16 additions & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ py_test(


# --------------------------------------------------------------------
# Connector tests
# Connector(V1) tests
# rllib/connector/
#
# Tag: connector
Expand All @@ -774,6 +774,21 @@ py_test(
srcs = ["connectors/tests/test_agent.py"]
)

# --------------------------------------------------------------------
# ConnectorV2 tests
# rllib/connector/
#
# Tag: connector_v2
# --------------------------------------------------------------------

# TODO (sven): Add these tests in a separate PR.
# py_test(
# name = "connectors/tests/test_connector_v2",
# tags = ["team:rllib", "connector_v2"],
# size = "small",
# srcs = ["connectors/tests/test_connector_v2.py"]
# )

# --------------------------------------------------------------------
# Env tests
# rllib/env/
Expand Down
18 changes: 9 additions & 9 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,11 @@ def setup(self, config: AlgorithmConfig) -> None:
config_obj.env = self._env_id
self.config = config_obj

self._uses_new_env_runners = (
self.config.env_runner_cls is not None
and not issubclass(self.config.env_runner_cls, RolloutWorker)
)

# Set Algorithm's seed after we have - if necessary - enabled
# tf eager-execution.
update_global_seed_if_necessary(self.config.framework_str, self.config.seed)
Expand Down Expand Up @@ -751,13 +756,12 @@ def setup(self, config: AlgorithmConfig) -> None:
)

# Only when using RolloutWorkers: Update also the worker set's
# `should_module_be_updated_fn` (analogous to is_policy_to_train).
# `is_policy_to_train` (analogous to LearnerGroup's
# `should_module_be_updated_fn`).
# Note that with the new EnvRunner API in combination with the new stack,
# this information only needs to be kept in the LearnerGroup and not on the
# EnvRunners anymore.
if self.config.env_runner_cls is None or issubclass(
self.config.env_runner_cls, RolloutWorker
):
if not self._uses_new_env_runners:
update_fn = self.learner_group.should_module_be_updated_fn
self.workers.foreach_worker(
lambda w: w.set_is_policy_to_train(update_fn),
Expand Down Expand Up @@ -3030,11 +3034,7 @@ def _run_one_evaluation(
"""
eval_func_to_use = (
self._evaluate_async_with_env_runner
if (
self.config.enable_async_evaluation
and self.config.env_runner_cls is not None
and not issubclass(self.config.env_runner_cls, RolloutWorker)
)
if (self.config.enable_async_evaluation and self._uses_new_env_runners)
else self._evaluate_async
if self.config.enable_async_evaluation
else self.evaluate
Expand Down
143 changes: 143 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@

if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.evaluation.episode import Episode as OldEpisode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -327,6 +329,8 @@ def __init__(self, algo_class=None):
self.num_envs_per_worker = 1
self.create_env_on_local_worker = False
self.enable_connectors = True
self._env_to_module_connector = None
self._module_to_env_connector = None
# TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
# and `sample_duration_unit` (replacing batch_mode), like we do it
# in the evaluation config).
Expand Down Expand Up @@ -374,6 +378,7 @@ def __init__(self, algo_class=None):
except AttributeError:
pass

self._learner_connector = None
self.optimizer = {}
self.max_requests_in_flight_per_sampler_worker = 2
self._learner_class = None
Expand Down Expand Up @@ -1152,6 +1157,121 @@ class directly. Note that this arg can also be specified via
logger_creator=self.logger_creator,
)

def build_env_to_module_connector(self, env):
from ray.rllib.connectors.env_to_module import (
EnvToModulePipeline,
DefaultEnvToModule,
)

custom_connectors = []
# Create an env-to-module connector pipeline (including RLlib's default
# env->module connector piece) and return it.
if self._env_to_module_connector is not None:
val_ = self._env_to_module_connector(env)

from ray.rllib.connectors.connector_v2 import ConnectorV2

if isinstance(val_, ConnectorV2) and not isinstance(
val_, EnvToModulePipeline
):
custom_connectors = [val_]
elif isinstance(val_, (list, tuple)):
custom_connectors = list(val_)
else:
return val_

pipeline = EnvToModulePipeline(
connectors=custom_connectors,
input_observation_space=env.single_observation_space,
input_action_space=env.single_action_space,
env=env,
)
pipeline.append(
DefaultEnvToModule(
input_observation_space=pipeline.observation_space,
input_action_space=pipeline.action_space,
env=env,
)
)
return pipeline

def build_module_to_env_connector(self, env):

from ray.rllib.connectors.module_to_env import (
DefaultModuleToEnv,
ModuleToEnvPipeline,
)

custom_connectors = []
# Create a module-to-env connector pipeline (including RLlib's default
# module->env connector piece) and return it.
if self._module_to_env_connector is not None:
val_ = self._module_to_env_connector(env)

from ray.rllib.connectors.connector_v2 import ConnectorV2

if isinstance(val_, ConnectorV2) and not isinstance(
val_, ModuleToEnvPipeline
):
custom_connectors = [val_]
elif isinstance(val_, (list, tuple)):
custom_connectors = list(val_)
else:
return val_

pipeline = ModuleToEnvPipeline(
connectors=custom_connectors,
input_observation_space=env.single_observation_space,
input_action_space=env.single_action_space,
env=env,
)
pipeline.append(
DefaultModuleToEnv(
input_observation_space=pipeline.observation_space,
input_action_space=pipeline.action_space,
env=env,
normalize_actions=self.normalize_actions,
clip_actions=self.clip_actions,
)
)
return pipeline

def build_learner_connector(self, input_observation_space, input_action_space):
from ray.rllib.connectors.learner import (
DefaultLearnerConnector,
LearnerConnectorPipeline,
)

custom_connectors = []
# Create a learner connector pipeline (including RLlib's default
# learner connector piece) and return it.
if self._learner_connector is not None:
val_ = self._learner_connector(input_observation_space, input_action_space)

from ray.rllib.connectors.connector_v2 import ConnectorV2

if isinstance(val_, ConnectorV2) and not isinstance(
val_, LearnerConnectorPipeline
):
custom_connectors = [val_]
elif isinstance(val_, (list, tuple)):
custom_connectors = list(val_)
else:
return val_

pipeline = LearnerConnectorPipeline(
connectors=custom_connectors,
input_observation_space=input_observation_space,
input_action_space=input_action_space,
)
pipeline.append(
DefaultLearnerConnector(
input_observation_space=pipeline.observation_space,
input_action_space=pipeline.action_space,
)
)
return pipeline

def build_learner_group(
self,
*,
Expand Down Expand Up @@ -1605,6 +1725,12 @@ def rollouts(
create_env_on_local_worker: Optional[bool] = NotProvided,
sample_collector: Optional[Type[SampleCollector]] = NotProvided,
enable_connectors: Optional[bool] = NotProvided,
env_to_module_connector: Optional[
Callable[[EnvType], "ConnectorV2"]
] = NotProvided,
module_to_env_connector: Optional[
Callable[[EnvType, "RLModule"], "ConnectorV2"]
] = NotProvided,
use_worker_filter_stats: Optional[bool] = NotProvided,
update_worker_filter_stats: Optional[bool] = NotProvided,
rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
Expand Down Expand Up @@ -1650,6 +1776,11 @@ def rollouts(
enable_connectors: Use connector based environment runner, so that all
preprocessing of obs and postprocessing of actions are done in agent
and action connectors.
env_to_module_connector: A callable taking an Env as input arg and returning
an env-to-module ConnectorV2 (might be a pipeline) object.
module_to_env_connector: A callable taking an Env and an RLModule as input
args and returning a module-to-env ConnectorV2 (might be a pipeline)
object.
use_worker_filter_stats: Whether to use the workers in the WorkerSet to
update the central filters (held by the local worker). If False, stats
from the workers will not be used and discarded.
Expand Down Expand Up @@ -1737,6 +1868,10 @@ def rollouts(
self.create_env_on_local_worker = create_env_on_local_worker
if enable_connectors is not NotProvided:
self.enable_connectors = enable_connectors
if env_to_module_connector is not NotProvided:
self._env_to_module_connector = env_to_module_connector
if module_to_env_connector is not NotProvided:
self._module_to_env_connector = module_to_env_connector
if use_worker_filter_stats is not NotProvided:
self.use_worker_filter_stats = use_worker_filter_stats
if update_worker_filter_stats is not NotProvided:
Expand Down Expand Up @@ -1855,6 +1990,9 @@ def training(
optimizer: Optional[dict] = NotProvided,
max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
learner_class: Optional[Type["Learner"]] = NotProvided,
learner_connector: Optional[
Callable[["RLModule"], "ConnectorV2"]
] = NotProvided,
# Deprecated arg.
_enable_learner_api: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
Expand Down Expand Up @@ -1916,6 +2054,9 @@ def training(
in your experiment of timesteps.
learner_class: The `Learner` class to use for (distributed) updating of the
RLModule. Only used when `_enable_new_api_stack=True`.
learner_connector: A callable taking an env observation space and an env
action space as inputs and returning a learner ConnectorV2 (might be
a pipeline) object.
Returns:
This updated AlgorithmConfig object.
Expand Down Expand Up @@ -1960,6 +2101,8 @@ def training(
)
if learner_class is not NotProvided:
self._learner_class = learner_class
if learner_connector is not NotProvided:
self._learner_connector = learner_connector

return self

Expand Down
7 changes: 3 additions & 4 deletions rllib/algorithms/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,17 @@ class ImpalaConfig(AlgorithmConfig):
# Update the config object.
config = config.training(
lr=tune.grid_search([0.0001, ]), grad_clip=20.0
lr=tune.grid_search([0.0001, 0.0002]), grad_clip=20.0
)
config = config.resources(num_gpus=0)
config = config.rollouts(num_rollout_workers=1)
# Set the config object's env.
config = config.environment(env="CartPole-v1")
# Use to_dict() to get the old-style python config dict
# when running with tune.
# Run with tune.
tune.Tuner(
"IMPALA",
param_space=config,
run_config=air.RunConfig(stop={"training_iteration": 1}),
param_space=config.to_dict(),
).fit()
.. testoutput::
Expand Down
5 changes: 2 additions & 3 deletions rllib/algorithms/pg/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ class PGConfig(AlgorithmConfig):
>>> config = config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config = config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> # Run with tune.
>>> tune.Tuner( # doctest: +SKIP
... "PG",
... run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
... param_space=config.to_dict(),
... param_space=config,
... ).fit()
"""

Expand Down
20 changes: 10 additions & 10 deletions rllib/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,10 @@ def training(
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

# TODO (sven): Move to generic AlgorithmConfig.
if lr_schedule is not NotProvided:
self.lr_schedule = lr_schedule
if use_critic is not NotProvided:
self.use_critic = use_critic
# TODO (Kourosh) This is experimental. Set learner_hps parameters as
# well. Don't forget to remove .use_critic from algorithm config.
# TODO (Kourosh) This is experimental.
# Don't forget to remove .use_critic from algorithm config.
if use_gae is not NotProvided:
self.use_gae = use_gae
if lambda_ is not NotProvided:
Expand All @@ -280,15 +277,19 @@ def training(
self.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not NotProvided:
self.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule
if clip_param is not NotProvided:
self.clip_param = clip_param
if vf_clip_param is not NotProvided:
self.vf_clip_param = vf_clip_param
if grad_clip is not NotProvided:
self.grad_clip = grad_clip

# TODO (sven): Remove these once new API stack is only option for PPO.
if lr_schedule is not NotProvided:
self.lr_schedule = lr_schedule
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule

return self

@override(AlgorithmConfig)
Expand All @@ -312,8 +313,8 @@ def validate(self) -> None:
raise ValueError(
f"`sgd_minibatch_size` ({self.sgd_minibatch_size}) must be <= "
f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch"
f" is be split into {self.sgd_minibatch_size} chunks, each of which is "
f"iterated over (used for updating the policy) {self.num_sgd_iter} "
f" will be split into {self.sgd_minibatch_size} chunks, each of which "
f"is iterated over (used for updating the policy) {self.num_sgd_iter} "
"times."
)

Expand Down Expand Up @@ -476,7 +477,6 @@ def training_step(self) -> ResultDict:
self.workers.local_worker().set_weights(weights)

if self.config._enable_new_api_stack:

kl_dict = {}
if self.config.use_kl_loss:
for pid in policies_to_update:
Expand Down
Loading

0 comments on commit bd555a0

Please sign in to comment.