[RLlib; Offline RL] CQL: Support multi-GPU/CPU setup and different learning rates for actor, critic, and alpha. #47402

Merged
48 changes: 41 additions & 7 deletions rllib/algorithms/cql/cql.py
@@ -2,19 +2,20 @@
from typing import Optional, Type, Union

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import (
SAC,
SACConfig,
)
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
@@ -48,7 +49,7 @@
SAMPLE_TIMER,
TIMERS,
)
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -83,7 +84,14 @@ def __init__(self, algo_class=None):
self.lagrangian = False
self.lagrangian_thresh = 5.0
self.min_q_weight = 5.0
self.deterministic_backup = True
self.lr = 3e-4
Contributor:
Ah, so for the new stack, users have to set this to None, manually? I guess this is ok (explicit is always good).

Collaborator (Author):
Yes, exactly. We discussed this in the other PR concerning SAC.
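
For illustration, a minimal new-stack config sketch (values are placeholders; assumes RLlib's `api_stack()` setter and the per-component learning-rate arguments this PR adds):

```python
from ray.rllib.algorithms.cql import CQLConfig

config = (
    CQLConfig()
    # Switch to the new API stack (RLModule + Learner).
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .training(
        lr=None,         # The base LR must be set to None on the new stack.
        actor_lr=1e-4,   # Policy (actor) learning rate.
        critic_lr=1e-3,  # Q-function (critic) learning rate.
        alpha_lr=1e-3,   # Entropy-temperature (alpha) learning rate.
    )
)
```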

# Note: The new stack defines a learning rate for each component.
# The base learning rate `lr` has to be set to `None` when using
# the new stack.
self.actor_lr = 1e-4
self.critic_lr = 1e-3
self.alpha_lr = 1e-3

# Changes to Algorithm's/SACConfig's default:

@@ -105,6 +113,7 @@ def training(
lagrangian: Optional[bool] = NotProvided,
lagrangian_thresh: Optional[float] = NotProvided,
min_q_weight: Optional[float] = NotProvided,
deterministic_backup: Optional[bool] = NotProvided,
**kwargs,
) -> "CQLConfig":
"""Sets the training-related configuration.
@@ -116,6 +125,8 @@
lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
lagrangian_thresh: Lagrangian threshold.
min_q_weight: Min Q weight multiplier.
deterministic_backup: Whether to use a deterministic backup, i.e. to
omit the entropy term from the target in the Bellman update. Defaults
to `True`.

Returns:
This updated AlgorithmConfig object.
@@ -135,6 +146,8 @@
self.lagrangian_thresh = lagrangian_thresh
if min_q_weight is not NotProvided:
self.min_q_weight = min_q_weight
if deterministic_backup is not NotProvided:
self.deterministic_backup = deterministic_backup

return self
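
As an example, the new flag can be toggled through the same setter (a sketch; `False` keeps the entropy term in the target):

```python
# Sketch: re-enable the entropy term in the Bellman target by turning
# off the (default) deterministic backup.
config = config.training(deterministic_backup=False)
```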

@@ -234,6 +247,27 @@ def validate(self) -> None:
"Set this hyperparameter in the `AlgorithmConfig.offline_data`."
)

@override(SACConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
from ray.rllib.algorithms.sac.sac_catalog import SACCatalog

if self.framework_str == "torch":
from ray.rllib.algorithms.cql.torch.cql_torch_rl_module import (
CQLTorchRLModule,
)

return RLModuleSpec(module_class=CQLTorchRLModule, catalog_class=SACCatalog)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. " "Use `torch`."
)
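
A quick sanity-check sketch of this override (assumes a torch framework setting; the assert is illustrative only):

```python
spec = CQLConfig().framework("torch").get_default_rl_module_spec()
# The default spec should point at the torch CQL RLModule.
assert spec.module_class.__name__ == "CQLTorchRLModule"
```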

@property
def _model_config_auto_includes(self):
return super()._model_config_auto_includes | {
"num_actions": self.num_actions,
}
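
The property merges CQL's `num_actions` into the keys that are auto-forwarded to the RLModule's `model_config`. Roughly, as a simplified sketch (the parent keys shown are hypothetical):

```python
# Simplified illustration of the dict union (Python 3.9+) above:
parent_includes = {"twin_q": True}  # hypothetical keys from SACConfig
merged = parent_includes | {"num_actions": 10}
# -> {"twin_q": True, "num_actions": 10}; `num_actions` thus reaches
#    CQLTorchRLModule via `model_config` without manual plumbing.
```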


class CQL(SAC):
"""CQL (derived from SAC)."""