[RLlib; Offline RL] CQL: Support multi-GPU/CPU setup and different learning rates for actor, critic, and alpha. #47402

Merged
43 changes: 36 additions & 7 deletions rllib/algorithms/cql/cql.py
@@ -2,19 +2,20 @@
from typing import Optional, Type, Union

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import (
SAC,
SACConfig,
)
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import (
SAC,
SACConfig,
)
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
@@ -48,7 +49,7 @@
SAMPLE_TIMER,
TIMERS,
)
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -84,6 +85,12 @@ def __init__(self, algo_class=None):
self.lagrangian_thresh = 5.0
self.min_q_weight = 5.0
self.lr = 3e-4
Contributor

Ah, so for the new stack, users have to set this to None, manually? I guess this is ok (explicit is always good).

Collaborator Author

Yes, exactly. We discussed this in the other PR concerning SAC.
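
For illustration, a minimal, hedged sketch of how these settings combine on the new API stack. The `training()` and `learners()` setters are the generic `AlgorithmConfig` ones and are assumed here, not shown in this diff; their exact names can differ between Ray versions.

```python
from ray.rllib.algorithms.cql import CQLConfig

# Hedged sketch (not part of this diff): per-component learning rates on the
# new API stack, combined with a multi-GPU learner setup. The setter names are
# the generic AlgorithmConfig ones and may differ between Ray versions.
config = (
    CQLConfig()
    .training(
        # The base learning rate must be disabled explicitly on the new stack ...
        lr=None,
        # ... in favor of one learning rate per component.
        actor_lr=2e-4,
        critic_lr=8e-4,
        alpha_lr=9e-4,
    )
    # The multi-GPU/CPU part of this PR's title: scale training out over
    # several learners; adjust to the available hardware.
    .learners(num_learners=2, num_gpus_per_learner=1)
)
```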

# Note, the new stack defines a separate learning rate for each
# component. The base learning rate `lr` has to be set to `None`
# when using the new stack.
self.actor_lr = 2e-4
self.critic_lr = 8e-4
self.alpha_lr = 9e-4

# Changes to Algorithm's/SACConfig's default:

@@ -234,6 +241,28 @@ def validate(self) -> None:
"Set this hyperparameter in the `AlgorithmConfig.offline_data`."
)

@override(SACConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
from ray.rllib.algorithms.sac.sac_catalog import SACCatalog

if self.framework_str == "torch":
from ray.rllib.algorithms.cql.torch.cql_torch_rl_module import (
CQLTorchRLModule,
)

return RLModuleSpec(module_class=CQLTorchRLModule, catalog_class=SACCatalog)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. " "Use `torch`."
)
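
A brief, hedged usage sketch of the method above: only `torch` resolves to a default spec, anything else raises. The `framework()` setter is the standard `AlgorithmConfig` one and is an assumption here.

```python
from ray.rllib.algorithms.cql import CQLConfig

# Hedged sketch: resolving the default RLModule spec defined above.
config = CQLConfig().framework("torch")
spec = config.get_default_rl_module_spec()
# spec.module_class -> CQLTorchRLModule, spec.catalog_class -> SACCatalog

# Any non-torch framework string is rejected by the method:
# CQLConfig().framework("tf2").get_default_rl_module_spec()  # raises ValueError
```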

@property
def _model_config_auto_includes(self):
return super()._model_config_auto_includes | {
"num_actions": self.num_actions,
"_deterministic_loss": self._deterministic_loss,
Contributor

nit: Let's remove this deterministic loss thing. It's a relic from a long time ago (2020) when I was trying to debug SAC on torch vs our old SAC on tf. It serves no real purpose and just bloats the code.

Collaborator Author

Great!! That saves us many lines of code!

}
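
For context, a hedged sketch of where these auto-included keys surface. Both the merged `model_config` accessor and the `training(num_actions=...)` setter are assumptions here and may differ between Ray versions.

```python
from ray.rllib.algorithms.cql import CQLConfig

# Hedged sketch: the auto-included keys should end up in the merged model
# config that the RLModule is built with (accessor name is an assumption).
config = CQLConfig().framework("torch").training(num_actions=10)
# Expected (assumption): {"num_actions": 10, "_deterministic_loss": False, ...}
print(config.model_config)
```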


class CQL(SAC):
"""CQL (derived from SAC)."""
225 changes: 18 additions & 207 deletions rllib/algorithms/cql/torch/cql_torch_learner.py
@@ -1,4 +1,3 @@
import tree
from typing import Dict

from ray.air.constants import TRAINING_ITERATION
@@ -74,10 +73,6 @@ def compute_loss_for_module(
* (logps_curr.detach() + self.target_entropy[module_id])
)

# Get the current batch size. Note, this size might vary in case the
# last batch contains less than `train_batch_size_per_learner` examples.
batch_size = batch[Columns.OBS].shape[0]

# Get the current alpha.
alpha = torch.exp(self.curr_log_alpha[module_id])
# Start training with behavior cloning and turn to the classic Soft-Actor Critic
@@ -105,37 +100,17 @@

# The critic loss is composed of the standard SAC Critic L2 loss and the
# CQL entropy loss.
action_dist_next = action_dist_class.from_logits(
fwd_out["action_dist_inputs_next"]
)
# Sample the actions for the next state.
actions_next = (
# Note, we do not need to backpropagate through the
# next actions.
action_dist_next.sample()
if not config._deterministic_loss
else action_dist_next.to_deterministic().sample()
)

# Get the Q-values for the actually selected actions in the offline data.
# In the critic loss we use these as predictions.
q_selected = fwd_out[QF_PREDS]
if config.twin_q:
q_twin_selected = fwd_out[QF_TWIN_PREDS]

# Compute Q-values from the target Q network for the next state with the
# sampled actions for the next state.
q_batch_next = {
Columns.OBS: batch[Columns.NEXT_OBS],
Columns.ACTIONS: actions_next,
}
# Note, if `twin_q` is `True`, `SACTorchRLModule.forward_target` calculates
# the Q-values for both, `qf_target` and `qf_twin_target` and
# returns the minimum.
q_target_next = self.module[module_id].forward_target(q_batch_next)

# Now mask all Q-values with terminating next states in the targets.
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * fwd_out[
"q_target_next"
]

# Compute the right hand side of the Bellman equation. Detach this node
# from the computation graph as we do not want to backpropagate through
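
For orientation while the rest of this hunk is collapsed: the masked target above typically feeds an n-step Bellman backup of the following form. This is a hedged, generic rendering of the standard SAC/CQL critic target, not a copy of the hidden lines.

```latex
% Hedged, generic n-step Bellman target for the critic loss.
% r_t: reward, d_t: terminated flag, q_target: target-network Q-value.
y_t \;=\; r_t + \gamma^{\,n}\,(1 - d_t)\; q_{\mathrm{target}}\!\left(s_{t+n},\, a_{t+n}\right)
```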
@@ -171,121 +146,19 @@ def compute_loss_for_module(

# Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
# Note, the entropy version performs best in shown experiments.
# Generate random actions (from the mu distribution as named in Kumar et
# al. (2020))
low = torch.tensor(
self.module[module_id].config.action_space.low,
device=fwd_out[QF_PREDS].device,
)
high = torch.tensor(
self.module[module_id].config.action_space.high,
device=fwd_out[QF_PREDS].device,
)
num_samples = batch[Columns.ACTIONS].shape[0] * config.num_actions
actions_rand_repeat = low + (high - low) * torch.rand(
(num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
)

# Sample current and next actions (from the pi distribution as named in Kumar
# et al. (2020)) using repeated observations.
actions_curr_repeat, logps_curr_repeat, obs_curr_repeat = self._repeat_actions(
action_dist_class, batch[Columns.OBS], config.num_actions, module_id
)
actions_next_repeat, logps_next_repeat, obs_next_repeat = self._repeat_actions(
action_dist_class, batch[Columns.NEXT_OBS], config.num_actions, module_id
)

# Calculate the Q-values for all actions.
batch_rand_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_rand_repeat,
}
# Note, we need here the Q-values from the base Q-value function
# and not the minimum with an eventual Q-value twin.
q_rand_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the random actions, if needed.
if config.twin_q:
q_twin_rand_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_rand_repeat
batch_curr_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_curr_repeat,
}
q_curr_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the repeated actions from the current policy,
# if needed.
if config.twin_q:
q_twin_curr_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_curr_repeat
batch_next_repeat = {
# Note, we use here the current observations b/c we want to keep the
# state fix while sampling the actions.
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_next_repeat,
}
q_next_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate also the twin Q-values for the current policy and next actions,
# if needed.
if config.twin_q:
q_twin_next_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_next_repeat

# Compute the log-probabilities for the random actions.
# Compute the log-probabilities for the random actions (note, we generate random
# actions (from the mu distribution as named in Kumar et al. (2020))).
# Note, all actions, action log-probabilities and Q-values are already computed
# by the module's `_forward_train` method.
# TODO (simon): This is the density for a discrete uniform, however, actions
# come from a continuous one. So actually this density should use (1/(high-low))
# instead of (1/2).
random_density = torch.log(
torch.pow(
torch.tensor(
actions_curr_repeat.shape[-1], device=actions_curr_repeat.device
fwd_out["actions_curr_repeat"].shape[-1],
device=fwd_out["actions_curr_repeat"].device,
),
0.5,
)
@@ -294,9 +167,9 @@ def compute_loss_for_module(
# entropy version of CQL).
q_repeat = torch.cat(
[
q_rand_repeat - random_density,
q_next_repeat - logps_next_repeat.detach(),
q_curr_repeat - logps_curr_repeat.detach(),
fwd_out["q_rand_repeat"] - random_density,
fwd_out["q_next_repeat"] - fwd_out["logps_next_repeat"],
fwd_out["q_curr_repeat"] - fwd_out["logps_curr_repeat"],
],
dim=1,
)
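
As a reading aid for the concatenation above (the rest of the computation is collapsed): a hedged sketch of the entropy-version CQL regularizer that a tensor like `q_repeat` typically feeds, plus the corrected uniform log-density from the TODO. `temperature` and `min_q_weight` refer to the config attributes of the same name; everything else is illustrative.

```python
import torch

# Hedged sketch (illustrative, not the collapsed PR code): the entropy version
# of CQL penalizes large Q-values under out-of-distribution actions
# (logsumexp over the concatenated Q-values) relative to the Q-values of the
# dataset actions.
def cql_regularizer(q_repeat, q_selected, temperature, min_q_weight):
    # q_repeat: (B, 3 * num_actions, 1) -- random/current/next-action Q-values,
    # already importance-corrected as in the concatenation above.
    # q_selected: (B,) -- Q-values of the actions stored in the offline batch.
    ood_term = torch.logsumexp(q_repeat / temperature, dim=1).mean() * temperature
    data_term = q_selected.mean()
    return (ood_term - data_term) * min_q_weight


# Per the TODO above: for a uniform "mu" distribution over a box action space,
# the log-density is -sum(log(high - low)); a 1/2-based density only holds
# if every action dimension spans [-1, 1].
def uniform_log_density(low, high):
    return -torch.log(high - low).sum()
```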
@@ -313,9 +186,9 @@ def compute_loss_for_module(
if config.twin_q:
q_twin_repeat = torch.cat(
[
q_twin_rand_repeat - random_density,
q_twin_next_repeat - logps_next_repeat.detach(),
q_twin_curr_repeat - logps_curr_repeat.detach(),
fwd_out["q_twin_rand_repeat"] - random_density,
fwd_out["q_twin_next_repeat"] - fwd_out["logps_next_repeat"],
fwd_out["q_twin_curr_repeat"] - fwd_out["logps_curr_repeat"],
],
dim=1,
)
@@ -352,9 +225,9 @@ def compute_loss_for_module(
"target_entropy": self.target_entropy[module_id],
"actions_curr_policy": torch.mean(actions_curr),
LOGPS_KEY: torch.mean(logps_curr),
QF_MEAN_KEY: torch.mean(q_curr_repeat),
QF_MAX_KEY: torch.max(q_curr_repeat),
QF_MIN_KEY: torch.min(q_curr_repeat),
QF_MEAN_KEY: torch.mean(fwd_out["q_curr_repeat"]),
QF_MAX_KEY: torch.max(fwd_out["q_curr_repeat"]),
QF_MIN_KEY: torch.min(fwd_out["q_curr_repeat"]),
TD_ERROR_MEAN_KEY: torch.mean(td_error),
},
key=module_id,
@@ -406,65 +279,3 @@ def compute_gradients(
)

return grads

def _repeat_tensor(self, tensor, repeat):
"""Generates a repeated version of a tensor.

The repetition is done similarly to `np.repeat`: each value is repeated
instead of the complete vector.

Args:
tensor: The tensor to be repeated.
repeat: How often each value in the tensor should be repeated.

Returns:
A tensor holding `repeat` repeated values of the input `tensor`
"""
# Insert the new dimension at axis 1 into the tensor.
t_repeat = tensor.unsqueeze(1)
# Repeat the tensor along the new dimension.
t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1)
# Stack the repeated values into the batch dimension.
t_repeat = t_repeat.view(-1, *tensor.shape[1:])
# Return the repeated tensor.
return t_repeat

def _repeat_actions(self, action_dist_class, obs, num_actions, module_id):
"""Generated actions for repeated observations.

The `num_actions` argument defines a multiplier used for generating `num_actions`
times as many actions as the batch size. Observations are repeated and then a
model forward pass is made.

Args:
action_dist_class: The action distribution class to be used for sampling
actions.
obs: A batched observation tensor.
num_actions: The multiplier for actions, i.e. how many more actions
than the batch size should be generated.
module_id: The module ID to be used when calling the forward pass.

Returns:
A tuple containing the sampled actions, their log-probabilities and the
repeated observations.
"""
# Receive the batch size.
batch_size = obs.shape[0]
# Repeat the observations `num_actions` times.
obs_repeat = tree.map_structure(
lambda t: self._repeat_tensor(t, num_actions), obs
)
# Generate a batch for the forward pass.
temp_batch = {Columns.OBS: obs_repeat}
# Run the forward pass in inference mode.
fwd_out = self.module[module_id].forward_inference(temp_batch)
# Generate the squashed Gaussian from the model's logits.
action_dist = action_dist_class.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
# Sample the actions. Note, we want to make a backward pass through
# these actions.
actions = action_dist.rsample()
# Compute the action log-probabilities.
action_logps = action_dist.logp(actions).view(batch_size, num_actions, 1)

# Return
return actions, action_logps, obs_repeat