Optimized polyak updates #106
stable_baselines3/common/utils.py
"""
params = list(params)
target_params = list(target_params)
if params:
maybe an assert instead? "assert len(params) > 0" + meaningful error message?
"the if is simply to get the device" ?
Never mind my previous reply, the if was garbage left from a different implementation.
The previous implementation used the first parameter to get the device and create a tensor with tau, but this version doesn't need it. I forgot about it and accidentally left it in.
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Mostly LGTM ;)
Just a minor remark, and waiting for the num_threads=1 case ;)
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
@partiallytyped i think we are done with that one, no?
LGTM :)
I will just do a quick performance check (one seed on two bullet env with SAC) and then merge ;)
LGTM, thanks =)
Optimized Polyak updates through in-place operations.
Description
Used th.no_grad to avoid storing the graph. Used in-place th.Tensor.mul and th.add operations to avoid creating extra tensors. Achieves speedups of 1.5x on CPU and 1.95x on GPU in Google Colab over the old version.
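For illustration, here is a minimal sketch of what such an in-place Polyak update could look like. The function name polyak_update_sketch and its signature are hypothetical and not taken from the PR diff. It assumes the usual soft-update rule target = (1 - tau) * target + tau * param and uses the in-place mul_ together with th.add(..., out=...) to avoid allocating new tensors.

```python
import torch as th


def polyak_update_sketch(params, target_params, tau):
    """Hypothetical helper: soft-update target_params towards params in place.

    Computes target = (1 - tau) * target + tau * param for each pair.
    """
    with th.no_grad():  # do not record an autograd graph for the update
        for param, target_param in zip(params, target_params):
            # Scale the target in place, no temporary tensor is allocated.
            target_param.data.mul_(1 - tau)
            # Add tau * param and write the result into the target's storage.
            th.add(target_param.data, param.data, alpha=tau, out=target_param.data)


if __name__ == "__main__":
    online = th.nn.Linear(4, 2)
    target = th.nn.Linear(4, 2)
    polyak_update_sketch(online.parameters(), target.parameters(), tau=0.005)
```

Because every write goes into the existing storage of the target tensors, optimizers or other code holding references to those parameters keep seeing the updated values.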
Motivation and Context
Closes #93
Types of changes
Checklist:
make lint
make pytest and make type both pass.