[RLlib; Offline RL] Add `CQLLearner` and `CQLTorchLearner`. #46969
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…oss and switched in actor loss from selected actions to sampled actions from the current policy. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
# Get the current batch size. Note, this size might vary in case the
# last batch contains less than `train_batch_size_per_learner` examples.
batch_size = batch[Columns.OBS].shape[0]
Can we explain this logic here and why we defend against different batch sizes coming in?
Don't we always expect the same batch size from the data pipeline?
If not, we should:
a) explain here why we are expecting various batch sizes
b) probably fix this logic here. What if we call `compute_loss_for_module` 10x with a batch of size `train_batch_size_per_learner - 1` (accumulating gradients for these, but not applying them) and then once with a batch of size `train_batch_size_per_learner`? In this case, the effective batch size would be roughly 11x the user-configured one, correct?
c) Also, what if the incoming batch is larger than `train_batch_size_per_learner`?
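The scenario in (b) can be checked with a few lines of arithmetic. This is only a toy model of the hypothetical case described in the comment (gradients accumulate on every call and are applied only when a full-sized batch arrives); `effective_batch_sizes` is an illustrative helper, not part of the PR or of RLlib.

```python
from typing import List


def effective_batch_sizes(batch_sizes: List[int], configured_size: int) -> List[int]:
    """Toy model: accumulate samples per call, 'apply' only on a full-sized batch."""
    accumulated = 0
    applied = []
    for size in batch_sizes:
        accumulated += size
        if size == configured_size:
            # Gradients would be applied here, covering everything accumulated so far.
            applied.append(accumulated)
            accumulated = 0
    return applied


configured = 256
# Ten batches that are one sample short, followed by one full batch.
sizes = [configured - 1] * 10 + [configured]
result = effective_batch_sizes(sizes, configured)
print(result)  # [2806]
print(round(result[0] / configured, 2))  # 10.96
```

With a configured size of 256, the single update would cover 2806 samples, which is the "roughly 11x" effect the comment asks about.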
@sven1977 I fully agree with your comment.
a) This was also part of the old stack algorithm. The reason is that when iterating over a dataset, the last batch can contain fewer than `train_batch_size_per_learner` samples.
b) The offline logic here is that each call to `compute_loss_for_module` is followed by a `compute_gradients` and an `apply_gradients` call. No SGD is run on the offline algorithms (yet). So the train batch size should always be as large as configured (except for the last batch, which can be smaller).
c) This case should not happen either. `iter_batches` takes care of this and ensures that every batch has the configured size, except for the last one. The short last batch can be avoided by setting a flag, at the cost of discarding its data.
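To make the batching behavior in (a) and (c) concrete, here is a minimal pure-Python sketch. `iter_fixed_batches` is a hypothetical helper written for illustration, and its `drop_last` argument is an assumption about how the flag mentioned above behaves; it is not the Ray Data implementation.

```python
from typing import Iterator, List, Sequence


def iter_fixed_batches(
    rows: Sequence[int], batch_size: int, drop_last: bool = False
) -> Iterator[List[int]]:
    """Yield consecutive batches; only the final batch can be short."""
    for start in range(0, len(rows), batch_size):
        batch = list(rows[start:start + batch_size])
        if drop_last and len(batch) < batch_size:
            # Skip the short final batch, discarding its rows.
            return
        yield batch


rows = list(range(10))
sizes_keep = [len(b) for b in iter_fixed_batches(rows, 4)]
sizes_drop = [len(b) for b in iter_fixed_batches(rows, 4, drop_last=True)]
print(sizes_keep)  # [4, 4, 2]
print(sizes_drop)  # [4, 4]
```

No batch is ever larger than the requested size, and dropping the last batch trades two rows of data for uniform batch sizes.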
rllib/algorithms/cql/cql_learner.py
Outdated
# for `alpha` and the target entropy defined.
super().build()

# Set up the gradient buffer to store gradients to apply
Can we add more explanation here as to why caching any grads is necessary?
Maybe I don't understand the data pipelines properly yet, but should we try to fix these to always produce the exact same batch sizes (or number of episode timesteps) per iteration?
This is not really caused by the batch sizes, but will only be executed if the batch size is as configured. The reason for the grad caching is merely to enable multiple passes through the network during loss calculation.
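The idea behind the gradient cache can be shown with a toy model in plain Python (no torch). Everything here is hypothetical: `ToyNetwork`, the parameter names, the loss names, and the gradient values are made up for illustration, and the "reset on every pass" behavior is an assumption standing in for whatever erases the earlier gradients in the real learner.

```python
from typing import Dict, Iterable


class ToyNetwork:
    """Hypothetical stand-in for a network: one gradient slot per parameter."""

    def __init__(self, param_names: Iterable[str]) -> None:
        self.grads: Dict[str, float] = {name: 0.0 for name in param_names}

    def backward(self, new_grads: Dict[str, float]) -> None:
        # Mimic a fresh pass: old gradients are reset before new ones are written.
        for name in self.grads:
            self.grads[name] = 0.0
        self.grads.update(new_grads)


network = ToyNetwork(["actor.w", "critic.w", "log_alpha"])

# One made-up gradient per loss term, each touching a different parameter.
passes = {
    "actor_loss": {"actor.w": 0.5},
    "critic_loss": {"critic.w": -1.25},
    "alpha_loss": {"log_alpha": 2.0},
}

grad_buffer: Dict[str, float] = {}
for loss_name, grads in passes.items():
    network.backward(grads)
    # Cache this pass's gradients before the next pass resets them.
    for name in grads:
        grad_buffer[name] = network.grads[name]

print(network.grads)  # {'actor.w': 0.0, 'critic.w': 0.0, 'log_alpha': 2.0}
print(grad_buffer)    # {'actor.w': 0.5, 'critic.w': -1.25, 'log_alpha': 2.0}
```

After the last pass the network only holds the most recent gradient, while the buffer still holds all three and could be returned in one piece.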
rllib/algorithms/cql/cql_learner.py
Outdated
# the policy.
# TODO (simon, sven): Add upstream information pieces into this timesteps
# call arg to Learner.update_...().
self.metrics.log_value(
Add: reduce="sum" here to make sure iterations are not EMA'd :) (the default).
Actually, you don't need this here anymore. B/c you are logging it below, this entry will be created automatically.
# TODO (simon, sven): Add upstream information pieces into this timesteps
# call arg to Learner.update_...().
self.metrics.log_value(
Now that reduce="sum" is already defined above, you can just do:
log_value(
    (ALL_MODULES, TRAINING_ITERATION),
    1,
    reduce="sum",
)
alpha = torch.exp(self.curr_log_alpha[module_id])
# Start training with behavior cloning and turn to the classic Soft-Actor Critic
# after `bc_iters` of training iterations.
if self.metrics.peek((ALL_MODULES, TRAINING_ITERATION)) >= config.bc_iters:
Change to:
self.metrics.peek(([your key]), default=0)
To cover the case in which this entry does not exist yet.
This way, you don't have to define it up front.
Just log into it with reduce="sum" (and value=1) and all will be good.
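The suggested pattern (peek with a default, then log with a sum reduction) can be sketched with a tiny stand-in logger. `TinyMetrics` is a hypothetical simplification written for this example, not RLlib's metrics logger: any reduction other than "sum" just overwrites the value here, and the key tuple and `bc_iters` value are made up.

```python
from typing import Any, Dict, Hashable, List


class TinyMetrics:
    """Hypothetical, simplified stand-in for a metrics logger."""

    def __init__(self) -> None:
        self._values: Dict[Hashable, Any] = {}

    def log_value(self, key: Hashable, value: Any, reduce: str = "mean") -> None:
        if reduce == "sum":
            # Accumulate, creating the entry on first use.
            self._values[key] = self._values.get(key, 0) + value
        else:
            # Simplification: every other reduction just keeps the latest value.
            self._values[key] = value

    def peek(self, key: Hashable, default: Any = None) -> Any:
        # Return the default instead of failing when the entry does not exist yet.
        return self._values.get(key, default)


ITER_KEY = ("all_modules", "training_iteration")  # made-up key for this sketch
bc_iters = 2

metrics = TinyMetrics()
phases: List[str] = []
for _ in range(4):
    # No up-front definition needed: a missing entry counts as 0 iterations.
    use_sac = metrics.peek(ITER_KEY, default=0) >= bc_iters
    phases.append("sac" if use_sac else "bc")
    metrics.log_value(ITER_KEY, 1, reduce="sum")

print(phases)  # ['bc', 'bc', 'sac', 'sac']
```

The first `peek` happens before anything was logged and still works, which is the case the default is meant to cover.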
Just some nits left, otherwise good to go! Awesome work @simonsays1980 .
…ven1977. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Why are these changes needed?
This PR proposes the `Learner` classes for `CQL` in the new API stack. It relies on the `SACTorchModule` (i.e. having a Q-network and Q-target-network) and implements the loss logic. Because three loss terms are optimized, the gradients are collected inside the `compute_loss_for_module` method and are then just returned from `compute_gradients`. The reason for this is that there are multiple passes through the networks, where a follow-up pass would erase the gradients computed in the pass before.
This PR is one of several PRs that should implement `CQL` in the new API stack.
Related issue number
Closes #37779
Checks
… `git commit -s`) in this PR.
… `scripts/format.sh` to lint the changes in this PR.
… method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.