[RLlib] Replace remaining mentions of "trainer" by "algorithm". (#36557)
sven1977 authored Jun 21, 2023
1 parent 7ed5c6d commit 827ab91
Showing 79 changed files with 219 additions and 360 deletions.
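In user-facing code the rename is mostly cosmetic: an AlgorithmConfig now builds an Algorithm (formerly "Trainer") with the same train/evaluate workflow. A minimal sketch of the renamed usage, assuming PPO and the CartPole-v1 environment (neither appears in this diff):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=2)
)

algo = config.build()        # Returns an Algorithm instance (formerly "Trainer").
results = algo.train()       # One training iteration; returns a result dict.
print(results["episode_reward_mean"])
algo.stop()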
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -3703,7 +3703,7 @@ py_test(
)

# Taking out this test for now: Mixed torch- and tf- policies within the same
# Trainer never really worked.
# Algorithm never really worked.
# py_test(
# name = "examples/multi_agent_two_trainers_mixed_torch_tf",
# main = "examples/multi_agent_two_trainers.py",
6 changes: 3 additions & 3 deletions rllib/algorithms/a3c/tests/test_a3c.py
@@ -64,14 +64,14 @@ def test_a3c_entropy_coeff_schedule(self):
min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20
)

def _step_n_times(trainer, n: int):
"""Step trainer n times.
def _step_n_times(algo, n: int):
"""Step Algorithm n times.
Returns:
entropy coefficient at the end of the execution.
"""
for _ in range(n):
results = trainer.train()
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"entropy_coeff"
]
120 changes: 27 additions & 93 deletions rllib/algorithms/algorithm.py
@@ -362,7 +362,9 @@ def __init__(
# Last resort: Create core AlgorithmConfig from merged dicts.
if isinstance(default_config, dict):
config = AlgorithmConfig.from_dict(
config_dict=self.merge_trainer_configs(default_config, config, True)
config_dict=self.merge_algorithm_configs(
default_config, config, True
)
)
# Default config is an AlgorithmConfig -> update its properties
# from the given config dict.
@@ -569,17 +571,17 @@ def setup(self, config: AlgorithmConfig) -> None:
)
self.config.off_policy_estimation_methods = ope_dict

# Deprecated way of implementing Trainer sub-classes (or "templates"
# Deprecated way of implementing Algorithm sub-classes (or "templates"
# via the `build_trainer` utility function).
# Instead, sub-classes should override the Trainable's `setup()`
# method and call super().setup() from within that override at some
# point.
# Old design: Override `Trainer._init`.
# Old design: Override `Algorithm._init`.
_init = False
try:
self._init(self.config, self.env_creator)
_init = True
# New design: Override `Trainable.setup()` (as indented by tune.Trainable)
# New design: Override `Algorithm.setup()` (as intended by tune.Trainable)
# and do or don't call `super().setup()` from within your override.
# By default, `super().setup()` will create both worker sets:
# "rollout workers" for collecting samples for training and - if
@@ -743,7 +745,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)

# TODO: Deprecated: In your sub-classes of Trainer, override `setup()`
# TODO: Deprecated: In your sub-classes of Algorithm, override `setup()`
# directly and call super().setup() from within it if you would like the
# default setup behavior plus some own setup logic.
# If you don't need the env/workers/config/etc.. setup for you by super,
@@ -767,13 +769,13 @@ def get_default_policy_class(

@override(Trainable)
def step(self) -> ResultDict:
"""Implements the main `Trainer.train()` logic.
"""Implements the main `Algorithm.train()` logic.
Takes n attempts to perform a single training step. Thereby
catches RayErrors resulting from worker failures. After n attempts,
fails gracefully.
Override this method in your Trainer sub-classes if you would like to
Override this method in your Algorithm sub-classes if you would like to
handle worker failures yourself.
Otherwise, override only `training_step()` to implement the core
algorithm logic.
@@ -815,7 +817,7 @@ def step(self) -> ResultDict:
if not evaluate_this_iter and self.config.always_attach_evaluation_results:
assert isinstance(
self.evaluation_metrics, dict
), "Trainer.evaluate() needs to return a dict."
), "Algorithm.evaluate() needs to return a dict."
results.update(self.evaluation_metrics)

if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
@@ -865,9 +867,6 @@ def evaluate(
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
Args:
duration_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
@@ -914,7 +913,7 @@ def evaluate(
):
raise ValueError(
"Cannot evaluate w/o an evaluation worker set in "
"the Trainer or w/o an env on the local worker!\n"
"the Algorithm or w/o an env on the local worker!\n"
"Try one of the following:\n1) Set "
"`evaluation_interval` >= 0 to force creating a "
"separate evaluation worker set.\n2) Set "
@@ -1105,7 +1104,7 @@ def duration_fn(num_units_done):
metrics["off_policy_estimator"][name] = avg_estimate

# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
# Save evaluation metrics on Algorithm, so it can be attached to
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}
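Evaluation results end up under the "evaluation" key as saved above; a hedged sketch of a config that creates the separate evaluation worker set the earlier error message asks for (PPO and CartPole-v1 are assumptions, not part of this diff):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .evaluation(
        evaluation_interval=1,            # Evaluate after every training iteration.
        evaluation_num_workers=1,         # Builds the separate evaluation worker set.
        evaluation_duration=10,           # Run 10 units per evaluation run ...
        evaluation_duration_unit="episodes",  # ... measured in episodes.
    )
)
algo = config.build()
algo.train()                              # Evaluation metrics get attached to the results.
eval_results = algo.evaluate()            # Or trigger an evaluation run explicitly.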

@@ -1298,7 +1297,7 @@ def remote_fn(worker):
metrics["off_policy_estimator"][name] = estimates

# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
# Save evaluation metrics on Algorithm, so it can be attached to
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

@@ -1360,7 +1359,7 @@ def training_step(self) -> ResultDict:
"""Default single iteration logic of an algorithm.
- Collect on-policy samples (SampleBatches) in parallel using the
Trainer's RolloutWorkers (@ray.remote).
Algorithm's RolloutWorkers (@ray.remote).
- Concatenate collected SampleBatches into one train batch.
- Note that we may have more than one policy in the multi-agent case:
Call the different policies' `learn_on_batch` (simple optimizer) OR
@@ -1431,10 +1430,10 @@ def training_step(self) -> ResultDict:
@staticmethod
def execution_plan(workers, config, **kwargs):
raise NotImplementedError(
"It is not longer recommended to use Trainer's `execution_plan` method/API."
"It is no longer supported to use the `Algorithm.execution_plan()` API!"
" Set `_disable_execution_plan_api=True` in your config and override the "
"`Trainer.training_step()` method with your algo's custom "
"execution logic."
"`Algorithm.training_step()` method with your algo's custom "
"execution logic instead."
)

@PublicAPI
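For reference, a minimal sketch of the replacement pattern the message above names: subclass Algorithm and override training_step(). The class name is illustrative, and the sampling/training helpers shown are the commonly used ones, assumed here rather than taken from this diff:

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step
from ray.rllib.utils.typing import ResultDict


class MyAlgo(Algorithm):
    def training_step(self) -> ResultDict:
        # Collect a train batch from this Algorithm's rollout workers.
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers, max_env_steps=4000
        )
        # Update the local worker's policy (or policies) on that batch.
        train_results = train_one_step(self, train_batch)
        # Push the updated weights back to the remote rollout workers.
        self.workers.sync_weights()
        return train_results

With `_disable_execution_plan_api=True` set in the config, as the message asks, Algorithm.step() calls this method once per training iteration.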
@@ -1454,9 +1453,6 @@ def compute_single_action(
episode: Optional[Episode] = None,
unsquash_action: Optional[bool] = None,
clip_action: Optional[bool] = None,
# Deprecated args.
unsquash_actions=DEPRECATED_VALUE,
clip_actions=DEPRECATED_VALUE,
# Kwargs placeholder for future compatibility.
**kwargs,
) -> Union[
@@ -1506,24 +1502,9 @@ def compute_single_action(
or we have an RNN-based Policy.
Raises:
KeyError: If the `policy_id` cannot be found in this Trainer's
local worker.
KeyError: If the `policy_id` cannot be found in this Algorithm's local
worker.
"""
if clip_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`clip_actions`=...)",
new="Trainer.compute_single_action(`clip_action`=...)",
error=True,
)
clip_action = clip_actions
if unsquash_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`unsquash_actions`=...)",
new="Trainer.compute_single_action(`unsquash_action`=...)",
error=True,
)
unsquash_action = unsquash_actions

# `unsquash_action` is None: Use value of config['normalize_actions'].
if unsquash_action is None:
unsquash_action = self.config.normalize_actions
@@ -1535,7 +1516,7 @@ def compute_single_action(
# are all None.
err_msg = (
"Provide either `input_dict` OR [`observation`, ...] as "
"args to Trainer.compute_single_action!"
"args to `Algorithm.compute_single_action()`!"
)
if input_dict is not None:
assert (
Expand All @@ -1549,12 +1530,12 @@ def compute_single_action(
assert observation is not None, err_msg

# Get the policy to compute the action for (in the multi-agent case,
# Trainer may hold >1 policies).
# Algorithm may hold >1 policies).
policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Trainer's local worker!"
f"Algorithm's local worker!"
)
local_worker = self.workers.local_worker()

@@ -1657,8 +1638,6 @@ def compute_actions(
episodes: Optional[List[Episode]] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
# Deprecated.
normalize_actions=None,
**kwargs,
):
"""Computes an action for the specified policy on the local Worker.
Expand Down Expand Up @@ -1700,14 +1679,6 @@ def compute_actions(
the full output of policy.compute_actions_from_input_dict() if
full_fetch=True or we have an RNN-based Policy.
"""
if normalize_actions is not None:
deprecation_warning(
old="Trainer.compute_actions(`normalize_actions`=...)",
new="Trainer.compute_actions(`unsquash_actions`=...)",
error=True,
)
unsquash_actions = normalize_actions

# `unsquash_actions` is None: Use value of config['normalize_actions'].
if unsquash_actions is None:
unsquash_actions = self.config.normalize_actions
@@ -1834,8 +1805,6 @@ def add_policy(
] = None,
evaluation_workers: bool = True,
module_spec: Optional[SingleAgentRLModuleSpec] = None,
# Deprecated.
workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = DEPRECATED_VALUE,
) -> Optional[Policy]:
"""Adds a new policy to this Algorithm.
@@ -1873,27 +1842,13 @@ def add_policy(
module_spec: In the new RLModule API we need to pass in the module_spec for
the new module that is supposed to be added. Knowing the policy spec is
not sufficient.
workers: A list of RolloutWorker/ActorHandles (remote
RolloutWorkers) to add this policy to. If defined, will only
add the given policy to these workers.
Returns:
The newly added policy (the copy that got added to the local
worker). If `workers` was provided, None is returned.
"""
validate_policy_id(policy_id, error=True)

if workers is not DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.add_policy(.., workers=..)",
help=(
"The `workers` argument to `Algorithm.add_policy()` is deprecated! "
"Please do not use it anymore."
),
error=True,
)

self.workers.add_policy(
policy_id,
policy_cls,
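A hedged sketch of add_policy() with the surviving arguments (the deprecated `workers` list is gone); PPOTorchPolicy and the spaces are illustrative assumptions:

import gymnasium as gym
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

new_policy = algo.add_policy(
    policy_id="opponent_v1",
    policy_cls=PPOTorchPolicy,
    observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    action_space=gym.spaces.Discrete(2),
    policies_to_train=["opponent_v1"],
    evaluation_workers=True,              # Also add the policy to the evaluation workers.
)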
@@ -2016,7 +1971,6 @@ def export_policy_model(
def export_policy_checkpoint(
self,
export_dir: str,
filename_prefix=DEPRECATED_VALUE, # deprecated arg, do not use anymore
policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
"""Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
Expand All @@ -2039,14 +1993,6 @@ def export_policy_checkpoint(
>>> algo.train() # doctest: +SKIP
>>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
"""
# `filename_prefix` should not longer be used as new Policy checkpoints
# contain more than one file with a fixed filename structure.
if filename_prefix != DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.export_policy_checkpoint(filename_prefix=...)",
error=True,
)

policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
@@ -2173,7 +2119,8 @@ def load_checkpoint(self, checkpoint: str) -> None:
def log_result(self, result: ResultDict) -> None:
# Log after the callback is invoked, so that the user has a chance
# to mutate the result.
# TODO: Remove `trainer` arg at some point to fully deprecate the old signature.
# TODO: Remove `algorithm` arg at some point to fully deprecate the old
# signature.
self.callbacks.on_train_result(algorithm=self, result=result)
# Then log according to Trainable's logging logic.
Trainable.log_result(self, result)
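The callback now receives the instance via the `algorithm` kwarg used in the call above; a minimal sketch of a custom callback consuming it (the class name and the extra result key are illustrative assumptions):

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result, **kwargs):
        # `algorithm` is the Algorithm instance (formerly passed as `trainer`).
        result["my_iteration"] = algorithm.iteration

# Wired up via the config, e.g. `config.callbacks(MyCallbacks)`.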
@@ -2477,7 +2424,7 @@ def get_auto_filled_metrics(
return auto_filled

@classmethod
def merge_trainer_configs(
def merge_algorithm_configs(
cls,
config1: AlgorithmConfigDict,
config2: PartialAlgorithmConfigDict,
@@ -2754,7 +2701,7 @@ def _checkpoint_info_to_algorithm_state(
if isinstance(default_config, AlgorithmConfig):
new_config = default_config.update_from_dict(state["config"])
else:
new_config = Algorithm.merge_trainer_configs(
new_config = Algorithm.merge_algorithm_configs(
default_config, state["config"]
)

@@ -3146,21 +3093,8 @@ def _record_usage(self, config):
alg = "USER_DEFINED"
record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

@Deprecated(new="Algorithm.compute_single_action()", error=True)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)

@Deprecated(new="construct WorkerSet(...) instance directly", error=True)
def _make_workers(self, *args, **kwargs):
pass

@Deprecated(new="AlgorithmConfig.validate()", error=False)
def validate_config(self, config):
pass

@staticmethod
@Deprecated(new="AlgorithmConfig.validate()", error=True)
def _validate_config(config, trainer_or_none):
def validate_config(self, config):
pass


10 changes: 5 additions & 5 deletions rllib/algorithms/algorithm_config.py
@@ -127,7 +127,7 @@ class AlgorithmConfig(_Config):
... .resources(num_gpus=0)
... .rollouts(num_rollout_workers=4)
... .callbacks(MemoryTrackingCallbacks)
>>> # A config object can be used to construct the respective Trainer.
>>> # A config object can be used to construct the respective Algorithm.
>>> rllib_algo = config.build() # doctest: +SKIP
Example:
Expand All @@ -139,7 +139,7 @@ class AlgorithmConfig(_Config):
>>> # Use `to_dict()` method to get the legacy plain python config dict
>>> # for usage with `tune.Tuner().fit()`.
>>> tune.Tuner( # doctest: +SKIP
... "[registered trainer class]", param_space=config.to_dict()
... "[registered Algorithm class]", param_space=config.to_dict()
... ).fit()
"""

@@ -234,7 +234,7 @@ def overrides(cls, **kwargs):
def __init__(self, algo_class=None):
# Define all settings and their default values.

# Define the default RLlib Trainer class that this AlgorithmConfig will be
# Define the default RLlib Algorithm class that this AlgorithmConfig will be
# applied to.
self.algo_class = algo_class

@@ -1125,7 +1125,7 @@ def resources(
`num_gpus_per_learner_worker` accordingly (e.g. 4 GPUs total, and model
needs 2 GPUs: `num_learner_workers = 2` and
`num_gpus_per_learner_worker = 2`)
num_cpus_per_learner_worker: Number of CPUs allocated per trainer worker.
num_cpus_per_learner_worker: Number of CPUs allocated per Learner worker.
Only necessary for custom processing pipeline inside each Learner
requiring multiple CPU cores. Ignored if `num_learner_workers = 0`.
num_gpus_per_learner_worker: Number of GPUs allocated per worker. If
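The worked example in the docstring (4 GPUs total, a model needing 2) maps to a config roughly like the following sketch; PPO and the environment are assumptions, not part of this diff:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .resources(
        num_learner_workers=2,            # Two Learner workers ...
        num_gpus_per_learner_worker=2,    # ... each holding 2 of the 4 GPUs.
        num_cpus_per_learner_worker=1,    # One CPU core per Learner worker.
    )
)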
@@ -3095,7 +3095,7 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
Returns:
The Learner class to use for this algorithm either as a class type or as
a string (e.g. ray.rllib.core.learner.testing.torch.BCTrainer).
a string (e.g. ray.rllib.core.learner.testing.torch.BC).
"""
raise NotImplementedError

(Diffs for the remaining 75 changed files not shown.)
