Commit
[RLlib] RLlib deprecation Notices Part 1 (algorithm/, evaluation/, execution/, models/jax/) (#36826)

Signed-off-by: Avnish <avnishnarayan@gmail.com>
avnishn authored Jun 28, 2023
1 parent 4b7ebaa commit 684e28b
Showing 40 changed files with 312 additions and 39 deletions.
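Every file shown below follows the same pattern: the module imports Deprecated and ALGO_DEPRECATION_WARNING from ray.rllib.utils.deprecation, and the algorithm class itself is wrapped in a @Deprecated(...) decorator that points to the algorithm's new home under rllib_contrib/. Because error=False, the class stays usable and only a warning is emitted. As a rough mental model only (a minimal sketch under assumed semantics, not the actual implementation in ray/rllib/utils/deprecation.py), such a class-level decorator could look like this:

import functools
import warnings

# Assumed wording; the real ALGO_DEPRECATION_WARNING constant is defined in
# ray/rllib/utils/deprecation.py and may read differently.
ALGO_DEPRECATION_WARNING = (
    "This algorithm will be removed from RLlib in a future release and is "
    "being moved to the ray/rllib_contrib directory."
)

def Deprecated(old: str, new: str, help: str, error: bool):
    """Sketch: warn (or raise) whenever the decorated class is instantiated."""
    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def patched_init(self, *args, **kwargs):
            msg = f"`{old}` has been deprecated. Use `{new}` instead. {help}"
            if error:
                raise DeprecationWarning(msg)
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        cls.__init__ = patched_init
        return cls

    return decorator

Under this sketch, error=False (used throughout the commit) means instantiating, say, A2C still works but logs the notice; flipping error=True in a later release would turn the same decorator into a hard failure.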
7 changes: 7 additions & 0 deletions rllib/algorithms/a2c/a2c.py
@@ -8,6 +8,7 @@
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
@@ -147,6 +148,12 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
return super().get_rollout_fragment_length(worker_index)


@Deprecated(
old="rllib/algorithms/a2c/",
new="rllib_contrib/a2c/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class A2C(A3C):
@classmethod
@override(A3C)
19 changes: 7 additions & 12 deletions rllib/algorithms/a3c/a3c.py
@@ -6,7 +6,7 @@
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
APPLY_GRADS_TIMER,
GRAD_WAIT_TIMER,
@@ -59,17 +59,6 @@ class A3CConfig(AlgorithmConfig):

def __init__(self, algo_class=None):
"""Initializes a A3CConfig instance."""
deprecation_warning(
old="rllib/algorithms/a3c/a3c.py",
new="rllib_contrib/a3c/",
help=(
"This algorithm will be "
"deprecated from RLlib in future releases. It is being moved to the "
"ray/rllib_contrib directory. See "
"https://github.com/ray-project/enhancements/blob/main/reps/2023-04-28-remove-algorithms-from-rllib.md" # noqa: E501
"for more details."
),
)
super().__init__(algo_class=algo_class or A3C)

# fmt: off
@@ -186,6 +175,12 @@ def validate(self) -> None:
raise ValueError("`num_workers` for A3C must be >= 1!")


@Deprecated(
old="rllib/algorithms/a3c/",
new="rllib_contrib/a3c/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class A3C(Algorithm):
@classmethod
@override(Algorithm)
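Note the shape of the a3c.py change above: the inline deprecation_warning(...) call is removed from A3CConfig.__init__ and the notice moves to a class-level @Deprecated decorator on A3C itself, so merely building a config no longer warns; constructing the algorithm does. A hedged usage sketch (environment and worker settings are illustrative; the exact warning text comes from ALGO_DEPRECATION_WARNING, which this diff does not show):

from ray.rllib.algorithms.a3c import A3CConfig

# After this commit, configuring alone should no longer emit the notice ...
config = (
    A3CConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=1)
)

# ... but config.build() instantiates the @Deprecated A3C class, which is
# expected to log a deprecation warning pointing at rllib_contrib/a3c/.
algo = config.build()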
7 changes: 7 additions & 0 deletions rllib/algorithms/alpha_star/alpha_star.py
@@ -21,6 +21,7 @@
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
@@ -242,6 +243,12 @@ def training(
return self


@Deprecated(
old="rllib/algorithms/alpha_star/",
new="rllib_contrib/alpha_star/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class AlphaStar(appo.APPO):
_allow_unknown_subkeys = appo.APPO._allow_unknown_subkeys + [
"league_builder_config",
13 changes: 11 additions & 2 deletions rllib/algorithms/alpha_zero/alpha_zero.py
@@ -17,7 +17,11 @@
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
ALGO_DEPRECATION_WARNING,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
@@ -332,6 +336,12 @@ def mcts_creator():
)


@Deprecated(
old="rllib/algorithms/alpha_star/",
new="rllib_contrib/alpha_star/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class AlphaZero(Algorithm):
@classmethod
@override(Algorithm)
@@ -352,7 +362,6 @@ def training_step(self) -> ResultDict:
Returns:
The results dict from executing the training iteration.
"""

# Sample n MultiAgentBatches from n workers.
with self._timers[SAMPLE_TIMER]:
new_sample_batches = synchronous_parallel_sample(
12 changes: 11 additions & 1 deletion rllib/algorithms/apex_ddpg/apex_ddpg.py
@@ -4,7 +4,11 @@
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
ALGO_DEPRECATION_WARNING,
)
from ray.rllib.utils.typing import (
ResultDict,
)
@@ -138,6 +142,12 @@ def training(
return self


@Deprecated(
old="rllib/algorithms/apex_ddpg/",
new="rllib_contrib/apex_ddpg/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class ApexDDPG(DDPG, ApexDQN):
@classmethod
@override(DDPG)
12 changes: 11 additions & 1 deletion rllib/algorithms/apex_dqn/apex_dqn.py
@@ -29,7 +29,11 @@
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
ALGO_DEPRECATION_WARNING,
)
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@@ -310,6 +314,12 @@ def validate(self) -> None:
super().validate()


@Deprecated(
old="rllib/algorithms/apex_dqn/",
new="rllib_contrib/apex_dqn/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class ApexDQN(DQN):
@override(Trainable)
def setup(self, config: AlgorithmConfig):
7 changes: 7 additions & 0 deletions rllib/algorithms/ars/ars.py
@@ -21,6 +21,7 @@
from ray.rllib.utils import FilterManager
from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
@@ -372,6 +373,12 @@ def get_policy_class(config: AlgorithmConfig):
return policy_cls


@Deprecated(
old="rllib/algorithms/ars/",
new="rllib_contrib/ars/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class ARS(Algorithm):
"""Large-scale implementation of Augmented Random Search in Ray."""

13 changes: 13 additions & 0 deletions rllib/algorithms/bandit/bandit.py
@@ -7,6 +7,7 @@
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING

logger = logging.getLogger(__name__)

@@ -82,6 +83,12 @@ def __init__(self):
# fmt: on


@Deprecated(
old="rllib/algorithms/bandit/",
new="rllib_contrib/bandit/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class BanditLinTS(Algorithm):
"""Bandit Algorithm using ThompsonSampling exploration."""

@@ -103,6 +110,12 @@ def get_default_policy_class(
raise NotImplementedError("Only `framework=[torch|tf2]` supported!")


@Deprecated(
old="rllib/algorithms/bandit/",
new="rllib_contrib/bandit/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class BanditLinUCB(Algorithm):
@classmethod
@override(Algorithm)
7 changes: 7 additions & 0 deletions rllib/algorithms/crr/crr.py
@@ -7,6 +7,7 @@
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_TRAINED,
@@ -197,6 +198,12 @@ def validate(self) -> None:
NUM_GRADIENT_UPDATES = "num_grad_updates"


@Deprecated(
old="rllib/algorithms/crr/",
new="rllib_contrib/crr/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class CRR(Algorithm):

# TODO: we have a circular dependency for get
12 changes: 11 additions & 1 deletion rllib/algorithms/ddpg/ddpg.py
@@ -5,7 +5,11 @@
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ, SimpleQConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
ALGO_DEPRECATION_WARNING,
)

logger = logging.getLogger(__name__)

@@ -287,6 +291,12 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
return self.rollout_fragment_length


@Deprecated(
old="rllib/algorithms/ddpg/",
new="rllib_contrib/ddpg/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class DDPG(SimpleQ):
@classmethod
@override(SimpleQ)
7 changes: 7 additions & 0 deletions rllib/algorithms/ddppo/ddppo.py
@@ -25,6 +25,7 @@
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
LEARN_ON_BATCH_TIMER,
NUM_AGENT_STEPS_SAMPLED,
@@ -226,6 +227,12 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
return self.rollout_fragment_length


@Deprecated(
old="rllib/algorithms/ddppo/",
new="rllib_contrib/ddppo/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class DDPPO(PPO):
@classmethod
@override(PPO)
7 changes: 7 additions & 0 deletions rllib/algorithms/dreamer/dreamer.py
@@ -19,6 +19,7 @@
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
@@ -334,6 +335,12 @@ def postprocess_gif(self, gif: np.ndarray):
return _postprocess_gif(gif=gif)


@Deprecated(
old="rllib/algorithms/dreamer/",
new="rllib_contrib/dreamer/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class Dreamer(Algorithm):
@classmethod
@override(Algorithm)
7 changes: 7 additions & 0 deletions rllib/algorithms/dt/dt.py
@@ -12,6 +12,7 @@
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
@@ -291,6 +292,12 @@ def validate(self) -> None:
), "replay_buffer's max_ep_len must equal rollout horizon."


@Deprecated(
old="rllib/algorithms/dt/",
new="rllib_contrib/dt/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class DT(Algorithm):
"""Implements Decision Transformer: https://arxiv.org/abs/2106.01345."""

8 changes: 7 additions & 1 deletion rllib/algorithms/es/es.py
@@ -19,7 +19,7 @@
from ray.rllib.utils import FilterManager
from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
@@ -372,6 +372,12 @@ def get_policy_class(config: AlgorithmConfig):
return policy_cls


@Deprecated(
old="rllib/algorithms/es/",
new="rllib_contrib/es/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class ES(Algorithm):
"""Large-scale implementation of Evolution Strategies in Ray."""

12 changes: 11 additions & 1 deletion rllib/algorithms/leela_chess_zero/leela_chess_zero.py
@@ -20,7 +20,11 @@
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.utils.replay_buffers import PrioritizedReplayBuffer
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
ALGO_DEPRECATION_WARNING,
)
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
@@ -348,6 +352,12 @@ def mcts_creator():
)


@Deprecated(
old="rllib/algorithms/leela_chess_zero/",
new="rllib_contrib/leela_chess_zero/",
help=ALGO_DEPRECATION_WARNING,
error=False,
)
class LeelaChessZero(Algorithm):
@classmethod
@override(Algorithm)