[RLlib] Compile update logic on learner and use cudagraphs #35759
Conversation
Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Related tensorboard that shows speedups on the rollout worker side:
@@ -123,8 +107,6 @@ def get_state(self) -> Mapping[str, Any]:
     @override(RLModule)
     def set_state(self, state_dict: Mapping[str, Any]) -> None:
         self.load_state_dict(state_dict)
-        if self._retrace_on_set_weights:
-            torch._dynamo.reset()
We don't need this with cudagraphs.
            compile_config.compile_forward_train
            or compile_config.compile_forward_inference
            or compile_config.compile_forward_exploration
        )
We can just compile all the forward methods. Methods that are not called will not be traced anyway.
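A minimal, self-contained sketch of what this suggestion amounts to (ToyModule and compile_all are hypothetical stand-ins, not RLlib code): torch.compile is lazy, so wrapping every forward method unconditionally costs nothing for methods that are never called.

```python
import torch
import torch.nn as nn


class ToyModule(nn.Module):
    """Hypothetical stand-in for an RLModule with three forward methods."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def _forward_train(self, x):
        return self.net(x)

    def _forward_inference(self, x):
        return self.net(x)

    def _forward_exploration(self, x):
        return self.net(x)

    def compile_all(self, backend="eager"):
        # Compile every forward method unconditionally; dynamo only traces
        # a function on its first call, so unused methods are never traced.
        self._forward_train = torch.compile(self._forward_train, backend=backend)
        self._forward_inference = torch.compile(self._forward_inference, backend=backend)
        self._forward_exploration = torch.compile(self._forward_exploration, backend=backend)


module = ToyModule()
module.compile_all()
# Only _forward_train gets traced here; the other two stay untouched.
print(module._forward_train(torch.randn(1, 4)).shape)
```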
-torch_dynamo_backend: str = "aot_eager" if sys.platform == "darwin" else "inductor"
-torch_dynamo_mode: str = "reduce-overhead"
+torch_dynamo_backend: str = (
+    "aot_eager" if sys.platform == "darwin" else "cudagraphs"
Makes it so that weight updates actually take effect.
+)

-self.torch_compile_worker_dynamo_mode = "reduce-overhead"
+self.torch_compile_worker_dynamo_mode = None
None will make it so that for any chosen backend, we use the default mode.
cudagraphs does not have a "reduce-overhead" mode, so we need to choose None here.
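For illustration, a hedged sketch of the resulting torch.compile call (the function f is made up; the backend and mode values are the ones discussed above). With mode=None, torch.compile falls back to whatever default the chosen backend defines:

```python
import torch


def f(x):
    return torch.sin(x) + torch.cos(x)


# "cudagraphs" has no "reduce-overhead" mode, so mode=None (backend default)
# is the safe choice; "inductor" would accept mode="reduce-overhead".
compiled_f = torch.compile(f, backend="cudagraphs", mode=None)

if torch.cuda.is_available():
    # The cudagraphs backend requires tensors to live on a CUDA device.
    print(compiled_f(torch.randn(8, device="cuda")))
```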
        settings.
        "complete_update" promises the highest performance gains, but may not work
        in some settings. By compiling only forward_train, you may already get
        some speedups and avoid issues that arise from compiling the entire update.
In some cases, there are slight performance differences when compiling forward_train vs. the complete update.
Until we have explored this and know whether we can eliminate one of the two options, we can use this switch to choose.
            backend=torch_compile_cfg.torch_dynamo_backend,
            mode=torch_compile_cfg.torch_dynamo_mode,
            **torch_compile_cfg.kwargs,
        )
When compiling the update, we need to reset and recompile the whole thing every time we add/remove a module.
add this comment to the code plz.
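For context, a tiny sketch of the behavior that comment describes (the helper name is made up and backend="eager" is just for illustration): clearing dynamo's cache forces the compiled update to be retraced on its next call, now covering the changed module set.

```python
import torch


def on_module_added_or_removed(update_fn, backend="eager"):
    # Hypothetical helper: the previously compiled update no longer matches
    # the new set of modules, so drop all cached graphs and recompile.
    torch._dynamo.reset()
    return torch.compile(update_fn, backend=backend)
```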
    torch_compile_cfg: Optional["TorchCompileConfig"] = None

    def validate(self):
You need to expose these parameters to the top-level algorithm_config. Right now what_to_compile is not surfaced in the algorithm config.
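A rough sketch of what surfacing the switch could look like (the class and method bodies here are illustrative only; the real AlgorithmConfig is far larger, and only the parameter names are taken from this PR):

```python
class AlgorithmConfigSketch:
    """Illustrative subset of an AlgorithmConfig exposing the new knobs."""

    def __init__(self):
        self.torch_compile_learner = False
        self.torch_compile_learner_what_to_compile = "forward_train"

    def framework(
        self,
        *,
        torch_compile_learner=None,
        torch_compile_learner_what_to_compile=None,
    ):
        # Only overwrite values the caller explicitly provided.
        if torch_compile_learner is not None:
            self.torch_compile_learner = torch_compile_learner
        if torch_compile_learner_what_to_compile is not None:
            self.torch_compile_learner_what_to_compile = (
                torch_compile_learner_what_to_compile
            )
        return self


config = AlgorithmConfigSketch().framework(
    torch_compile_learner=True,
    torch_compile_learner_what_to_compile="complete_update",
)
```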
def _get_learner(learning_rate: float = 1e-3) -> Learner:
    env = gym.make("CartPole-v1")
    # adding learning rate as a configurable parameter to avoid hardcoding it
    # and information leakage across tests that rely on knowing the LR value
    # that is used in the learner.
    learner = get_learner("torch", env, learning_rate=learning_rate)
    learner.build()

    return learner
modify get_learner in rllib.core.testing.utils and use it here?
    spec = get_module_spec(
        framework="torch", env=env, is_multi_agent=is_multi_agent
    )
    learner = BCTorchLearner(
why are you not using get_learner()?
    batch = MultiAgentBatch(
        {"another_module": reader.next(), "default_policy": reader.next()},
        0,
    )
do you really want to call reader.next() twice per iteration? is this intentional? you can obtain the batch once and use it in two places.
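The reviewer's suggestion, sketched in place (reader and MultiAgentBatch are the names from the snippet above; this is a fragment in that context, not a standalone script):

```python
# Read one batch and reuse it for both module IDs instead of consuming
# two batches per iteration.
batch_data = reader.next()
batch = MultiAgentBatch(
    {"another_module": batch_data, "default_policy": batch_data},
    0,
)
```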
    learner = BCTorchLearner(
        module_spec=spec,
        framework_hyperparameters=framework_hps,
    )
use get_learner
                self._framework_hyperparameters.torch_compile_cfg
            )
        else:
            assert isinstance(self._module, MultiAgentRLModule)
please add a descriptive error upon failure. e.g. expected type blah got type blah
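Concretely, the assertion from the snippet above could be made self-describing along these lines (message wording is just an example, shown in the context of the surrounding code):

```python
assert isinstance(self._module, MultiAgentRLModule), (
    f"Expected self._module to be of type MultiAgentRLModule, "
    f"but got {type(self._module).__name__}."
)
```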
        else:
            assert isinstance(self._module, MultiAgentRLModule)
            for module in self._module._rl_modules.values():
                module.compile(
you need to skip the module if it's not a TorchRLModule (e.g. it could be a RandomRLModule, neither torch nor TF)
in other words, only compile those that are TorchRLModule
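A sketch of the requested guard, using the names from the surrounding diff (the module.compile arguments are kept as in the hunk that follows):

```python
for module in self._module._rl_modules.values():
    # Skip anything that is not a TorchRLModule, e.g. a RandomRLModule,
    # which is neither a torch nor a tf module and cannot be compiled.
    if not isinstance(module, TorchRLModule):
        continue
    module.compile(
        backend=torch_compile_cfg.torch_dynamo_backend,
        mode=torch_compile_cfg.torch_dynamo_mode,
        **torch_compile_cfg.kwargs,
    )
```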
                    backend=torch_compile_cfg.torch_dynamo_backend,
                    mode=torch_compile_cfg.torch_dynamo_mode,
                    **torch_compile_cfg.kwargs,
                )
add this comment to the code plz.
make _dynamo_is_available() a util under torch_utils.py?
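The utility being asked for could look roughly like this (the import-probe pattern is a common way to feature-test; the exact RLlib implementation may differ):

```python
def _dynamo_is_available() -> bool:
    """Return True if torch._dynamo can be imported (PyTorch >= 2.0)."""
    try:
        import torch._dynamo  # noqa: F401

        return True
    except ImportError:
        return False
```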
@@ -272,7 +272,9 @@ def compile_me(input_dict):

    import torch._dynamo as dynamo

    # This is a helper method of dynamo to analyze where breaks occur.
    dynamo_explanation = dynamo.explain(compile_me, {"in": torch.Tensor([[1]])})
same comments above apply here :)
@kouroshHakha I've also added a configuration enumerator instead of relying on two long strings "complete_update" and "forward_train".
Thanks, just a super quick nit for Enums: https://realpython.com/python-enum/#getting-to-know-enumerations-in-python
rllib/algorithms/algorithm_config.py
@@ -282,6 +283,9 @@ def __init__(self, algo_class=None):
         }
         # Torch compile settings
         self.torch_compile_learner = False
+        self.torch_compile_learner_what_to_compile = (
+            TorchCompileWhatToCompile.forward_train
Enums should usually be all CAPS. e.g. FORWARD_TRAIN
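Applied to this config, the nit would look something like the following sketch (member names per the review; keeping the old strings as the values would preserve backward compatibility, though that detail is an assumption):

```python
from enum import Enum


class TorchCompileWhatToCompile(str, Enum):
    # ALL-CAPS member names, as is conventional for Enums.
    FORWARD_TRAIN = "forward_train"
    COMPLETE_UPDATE = "complete_update"


# Because the enum subclasses str, comparisons against the raw strings
# still work:
assert TorchCompileWhatToCompile.FORWARD_TRAIN == "forward_train"
```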
rllib/algorithms/algorithm_config.py
            torch_compile_learner_what_to_compile: A string specifying what to
                compile on the learner side if torch_compile_learner is True.
                This can be one of the following:
                - TorchCompileWhatToCompile.complete_update: Compile the
Just say see TorchCompileWhatToCompile for available options. This way you don't have to duplicate docstring if things change later down the line.
rllib/algorithms/algorithm_config.py
                  forward_train method, the loss calculation and the optimizer step
                  together on the TorchLearner.
                - TorchCompileWhatToCompile.forward_train: Compile only forward train.
                Note:
use .. note:: directives and check the rendered documentation to see if the formatting is correctly rendered.
fair
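For reference, a small sketch of how the note could be written as a properly formed rST directive inside the docstring (the wording is illustrative; the key fixes are the space after the two dots and the indented directive body):

```python
def _docstring_sketch():
    """torch_compile_learner_what_to_compile: What to compile on the learner
    side if torch_compile_learner is True. See TorchCompileWhatToCompile
    for available options.

    .. note::
        Compiling only forward_train may already yield some speedups while
        avoiding issues that arise from compiling the entire update.
    """
```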
…ct#35759)
Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
Why are these changes needed?
In the first attempt to leverage torch compile, we didn't introduce a compiled update method on the learner side (1), and we also had little success with torch compiling on the rollout worker side because weight updates would effectively not happen when we compiled (2).
For (1): This PR makes an attempt at compiling on the learner side, akin to what we do for eager tracing: we introduce a possibly_compiled_update() method on the TorchLearner side.

For (2): We get around the issue of not being able to set weights by using cudagraphs as the torch dynamo backend.
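To make the described pattern concrete, here is a minimal, self-contained sketch of a learner with a possibly_compiled_update() method (the class and everything inside it is illustrative; only the method name and the overall idea come from this PR, and backend="eager" merely stands in for the real backend choice):

```python
import torch


class TorchLearnerSketch:
    """Hypothetical learner that optionally compiles its update logic."""

    def __init__(self, compile_update: bool = False):
        self.module = torch.nn.Linear(4, 2)
        self.optim = torch.optim.SGD(self.module.parameters(), lr=1e-3)
        self._update_fn = self._uncompiled_update
        if compile_update:
            # Akin to eager tracing: swap in a compiled update method once,
            # at build time, so callers never need to know the difference.
            self._update_fn = torch.compile(
                self._uncompiled_update, backend="eager"
            )

    def _uncompiled_update(self, batch: torch.Tensor) -> torch.Tensor:
        # Toy loss: forward pass, backward pass, and optimizer step in one
        # function, mirroring a "complete update".
        loss = self.module(batch).pow(2).mean()
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss

    def possibly_compiled_update(self, batch: torch.Tensor) -> torch.Tensor:
        return self._update_fn(batch)


learner = TorchLearnerSketch(compile_update=True)
print(learner.possibly_compiled_update(torch.randn(8, 4)))
```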