Skip to content

Commit

Permalink
[RLlib] New API stack: On by default for BC/MARWIL/CQL. (#48599)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Nov 10, 2024
1 parent aee0a0e commit 03ea4f6
Show file tree
Hide file tree
Showing 20 changed files with 135 additions and 153 deletions.
5 changes: 3 additions & 2 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -948,12 +948,13 @@ py_test(
)

# CQL
# @OldAPIStack
py_test(
name = "test_cql",
name = "test_cql_old_api_stack",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
data = ["tests/data/pendulum/small.json"],
srcs = ["algorithms/cql/tests/test_cql.py"]
srcs = ["algorithms/cql/tests/test_cql_old_api_stack.py"]
)

# DQN
Expand Down
61 changes: 24 additions & 37 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,6 @@

Space = gym.Space

"""TODO(jungong, sven): in "offline_data" we can potentially unify all input types
under input and input_config keys. E.g.
input: sample
input_config {
env: CartPole-v1
}
or:
input: json_reader
input_config {
path: /tmp/
}
or:
input: dataset
input_config {
format: parquet
path: /tmp/
}
"""


if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
Expand Down Expand Up @@ -131,12 +112,13 @@ class AlgorithmConfig(_Config):
from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
# Construct a generic config object, specifying values within different
# sub-categories, e.g. "training".
config = (PPOConfig().training(gamma=0.9, lr=0.01)
.environment(env="CartPole-v1")
.resources(num_gpus=0)
.env_runners(num_env_runners=0)
.callbacks(MemoryTrackingCallbacks)
)
config = (
PPOConfig()
.training(gamma=0.9, lr=0.01)
.environment(env="CartPole-v1")
.env_runners(num_env_runners=0)
.callbacks(MemoryTrackingCallbacks)
)
# A config object can be used to construct the respective Algorithm.
rllib_algo = config.build()
Expand Down Expand Up @@ -321,10 +303,6 @@ def __init__(self, algo_class: Optional[type] = None):
# Default setting for skipping `nan` gradient updates.
self.torch_skip_nan_gradients = False

# `self.api_stack()`
self.enable_rl_module_and_learner = False
self.enable_env_runner_and_connector_v2 = False

# `self.environment()`
self.env = None
self.env_config = {}
Expand Down Expand Up @@ -425,7 +403,19 @@ def __init__(self, algo_class: Optional[type] = None):
self.explore = True
# This is not compatible with RLModules, which have a method
# `forward_exploration` to specify custom exploration behavior.
self.exploration_config = {}
if not hasattr(self, "exploration_config"):
# Helper to keep track of the original exploration config when dis-/enabling
# rl modules.
self._prior_exploration_config = None
self.exploration_config = {}

# `self.api_stack()`
self.enable_rl_module_and_learner = True
self.enable_env_runner_and_connector_v2 = True
self.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)

# `self.multi_agent()`
# TODO (sven): Prepare multi-agent setup for logging each agent's and each
Expand Down Expand Up @@ -549,9 +539,6 @@ def __init__(self, algo_class: Optional[type] = None):
# `self.rl_module()`
self._model_config = {}
self._rl_module_spec = None
# Helper to keep track of the original exploration config when dis-/enabling
# rl modules.
self.__prior_exploration_config = None
# Module ID specific config overrides.
self.algorithm_config_overrides_per_module = {}
# Cached, actual AlgorithmConfig objects derived from
Expand Down Expand Up @@ -1612,13 +1599,13 @@ def api_stack(
self.enable_rl_module_and_learner = enable_rl_module_and_learner

if enable_rl_module_and_learner is True and self.exploration_config:
self.__prior_exploration_config = self.exploration_config
self._prior_exploration_config = self.exploration_config
self.exploration_config = {}

elif enable_rl_module_and_learner is False and not self.exploration_config:
if self.__prior_exploration_config is not None:
self.exploration_config = self.__prior_exploration_config
self.__prior_exploration_config = None
if self._prior_exploration_config is not None:
self.exploration_config = self._prior_exploration_config
self._prior_exploration_config = None
else:
logger.warning(
"config.enable_rl_module_and_learner was set to False, but no "
Expand Down
8 changes: 2 additions & 6 deletions rllib/algorithms/appo/appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ class APPOConfig(IMPALAConfig):

def __init__(self, algo_class=None):
"""Initializes a APPOConfig instance."""
super().__init__(algo_class=algo_class or APPO)

self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
Expand All @@ -100,6 +98,8 @@ def __init__(self, algo_class=None):
# Add constructor kwargs here (if any).
}

super().__init__(algo_class=algo_class or APPO)

# fmt: off
# __sphinx_doc_begin__
# APPO specific settings:
Expand Down Expand Up @@ -138,10 +138,6 @@ def __init__(self, algo_class=None):
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.tau = 1.0
self.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# __sphinx_doc_end__
# fmt: on

Expand Down
4 changes: 1 addition & 3 deletions rllib/algorithms/appo/tests/test_appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import ray.rllib.algorithms.appo as appo
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics import (
LEARNER_RESULTS,
)
from ray.rllib.utils.metrics import LEARNER_RESULTS
from ray.rllib.utils.test_utils import (
check_train_results,
check_train_results_new_api_stack,
Expand Down
4 changes: 4 additions & 0 deletions rllib/algorithms/bc/tests/test_bc_old_api_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def test_bc_compilation_and_learning_from_offline_file(self):

config = (
bc.BCConfig()
.api_stack(
enable_env_runner_and_connector_v2=False,
enable_rl_module_and_learner=False,
)
.evaluation(
evaluation_interval=3,
evaluation_num_env_runners=1,
Expand Down
10 changes: 0 additions & 10 deletions rllib/algorithms/cql/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,9 @@ def __init__(self, algo_class=None):

# Changes to Algorithm's/SACConfig's default:

# `.api_stack()`
self.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
# .reporting()
self.min_sample_timesteps_per_iteration = 0
self.min_train_timesteps_per_iteration = 100
# `.api_stack()`
self.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
# fmt: on
# __sphinx_doc_end__

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_cql_compilation(self):

config = (
cql.CQLConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.environment(
env="Pendulum-v1",
)
Expand Down
21 changes: 9 additions & 12 deletions rllib/algorithms/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,19 @@ class DQNConfig(AlgorithmConfig):

def __init__(self, algo_class=None):
"""Initializes a DQNConfig instance."""
super().__init__(algo_class=algo_class or DQN)

# Overrides of AlgorithmConfig defaults
# `env_runners()`
# Set to `self.n_step`, if 'auto'.
self.rollout_fragment_length: Union[int, str] = "auto"
self.exploration_config = {
"type": "EpsilonGreedy",
"initial_epsilon": 1.0,
"final_epsilon": 0.02,
"epsilon_timesteps": 10000,
}

super().__init__(algo_class=algo_class or DQN)

# Overrides of AlgorithmConfig defaults
# `env_runners()`
# Set to `self.n_step`, if 'auto'.
self.rollout_fragment_length: Union[int, str] = "auto"
# New stack uses `epsilon` as either a constant value or a scheduler
# defined like this.
# TODO (simon): Ensure that users can understand how to provide epsilon.
Expand Down Expand Up @@ -174,7 +175,6 @@ def __init__(self, algo_class=None):
self.target_network_update_freq = 500
self.num_steps_sampled_before_learning_starts = 1000
self.store_buffer_in_checkpoints = False
self.lr_schedule = None
self.adam_epsilon = 1e-8

self.tau = 1.0
Expand Down Expand Up @@ -203,14 +203,11 @@ def __init__(self, algo_class=None):
# Beta parameter for sampling from prioritized replay buffer.
"beta": 0.4,
}
# `.api_stack()`
self.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# fmt: on
# __sphinx_doc_end__

self.lr_schedule = None # @OldAPIStack

# Deprecated
self.buffer_size = DEPRECATED_VALUE
self.prioritized_replay = DEPRECATED_VALUE
Expand Down
8 changes: 2 additions & 6 deletions rllib/algorithms/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class IMPALAConfig(AlgorithmConfig):

def __init__(self, algo_class=None):
"""Initializes a IMPALAConfig instance."""
super().__init__(algo_class=algo_class or IMPALA)

self.exploration_config = { # @OldAPIstack
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
Expand All @@ -135,6 +133,8 @@ def __init__(self, algo_class=None):
# Add constructor kwargs here (if any).
}

super().__init__(algo_class=algo_class or IMPALA)

# fmt: off
# __sphinx_doc_begin__

Expand Down Expand Up @@ -170,10 +170,6 @@ def __init__(self, algo_class=None):
self.num_env_runners = 2
self.lr = 0.0005
self.min_time_s_per_iteration = 10
self.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# __sphinx_doc_end__
# fmt: on

Expand Down
22 changes: 10 additions & 12 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ class MARWILConfig(AlgorithmConfig):

def __init__(self, algo_class=None):
"""Initializes a MARWILConfig instance."""
self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
# You can also provide the python class directly or the full location
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
# EpsilonGreedy").
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}

super().__init__(algo_class=algo_class or MARWIL)

# fmt: off
Expand Down Expand Up @@ -165,18 +175,6 @@ def __init__(self, algo_class=None):
self.lr = 1e-4
self.lambda_ = 1.0
self.train_batch_size = 2000
# TODO (Artur): MARWIL should not need an exploration config as an offline
# algorithm. However, the current implementation of the CRR algorithm
# requires it. Investigate.
self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
# You can also provide the python class directly or the full location
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
# EpsilonGreedy").
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}

# Materialize only the data in raw format, but not the mapped data b/c
# MARWIL uses a connector to calculate values and therefore the module
Expand Down
12 changes: 12 additions & 0 deletions rllib/algorithms/marwil/tests/test_marwil_old_api_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def test_marwil_compilation_and_learning_from_offline_file(self):

config = (
marwil.MARWILConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.env_runners(num_env_runners=2)
.environment(env="CartPole-v1")
.evaluation(
Expand Down Expand Up @@ -111,6 +115,10 @@ def test_marwil_cont_actions_from_offline_file(self):

config = (
marwil.MARWILConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.env_runners(num_env_runners=1)
.evaluation(
evaluation_num_env_runners=1,
Expand Down Expand Up @@ -148,6 +156,10 @@ def test_marwil_loss_function(self):

config = (
marwil.MARWILConfig()
.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
.env_runners(num_env_runners=0)
.offline_data(input_=[data_file])
) # Learn from offline data.
Expand Down
10 changes: 2 additions & 8 deletions rllib/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ class PPOConfig(AlgorithmConfig):

def __init__(self, algo_class=None):
"""Initializes a PPOConfig instance."""
super().__init__(algo_class=algo_class or PPO)

self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
Expand All @@ -122,6 +120,8 @@ def __init__(self, algo_class=None):
# Add constructor kwargs here (if any).
}

super().__init__(algo_class=algo_class or PPO)

# fmt: off
# __sphinx_doc_begin__
self.lr = 5e-5
Expand All @@ -146,12 +146,6 @@ def __init__(self, algo_class=None):

# Override some of AlgorithmConfig's default values with PPO-specific values.
self.num_env_runners = 2

# `.api_stack()`
self.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# __sphinx_doc_end__
# fmt: on

Expand Down
Loading

0 comments on commit 03ea4f6

Please sign in to comment.