
[RLlib] Compile update logic on learner and use cudagraphs #35759

Merged

Conversation


@ArturNiederfahrenhorst commented May 25, 2023

Why are these changes needed?

In the first attempt to leverage torch.compile, we (1) did not introduce a compiled update method on the learner side, and (2) had little success compiling on the rollout worker side because weight updates would effectively not take effect once we compiled.

For (1): This PR compiles the update on the learner side, akin to what we do for eager tracing, by introducing a possibly_compiled_update() method on the TorchLearner.

For (2): We get around the issue of not being able to set weights by using cudagraphs as the torch dynamo backend.
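
A minimal, self-contained sketch of the pattern (function and variable names here are illustrative stand-ins, not RLlib's actual API): the whole update is wrapped in one callable and compiled, using cudagraphs as the dynamo backend when a GPU is available.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# cudagraphs requires a CUDA device; fall back to a CPU-friendly backend otherwise.
backend = "cudagraphs" if device == "cuda" else "aot_eager"

# Hypothetical stand-ins for the learner's module, optimizer, and loss.
model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def _uncompiled_update(batch: torch.Tensor) -> torch.Tensor:
    # forward_train + loss + optimizer step fused into a single callable so
    # that torch.compile can capture the whole update.
    loss = model(batch).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

possibly_compiled_update = torch.compile(_uncompiled_update, backend=backend)
loss = possibly_compiled_update(torch.randn(32, 4, device=device))
```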

ArturNiederfahrenhorst and others added 2 commits May 24, 2023 21:30
@ArturNiederfahrenhorst marked this pull request as ready for review May 26, 2023 21:13
@ArturNiederfahrenhorst changed the title from "[RLlib] Second iteration of torch.compile() changes" to "[RLlib] Compile update logic on learner and use cudagraphs" May 26, 2023

ArturNiederfahrenhorst commented May 26, 2023

Related TensorBoard that shows speedups on the rollout worker side:

ArturNiederfahrenhorst and others added 6 commits May 30, 2023 15:42
@@ -123,8 +107,6 @@ def get_state(self) -> Mapping[str, Any]:
@override(RLModule)
def set_state(self, state_dict: Mapping[str, Any]) -> None:
    self.load_state_dict(state_dict)
    if self._retrace_on_set_weights:
        torch._dynamo.reset()
Contributor Author

We don't need this with cudagraphs.

compile_config.compile_forward_train
or compile_config.compile_forward_inference
or compile_config.compile_forward_exploration
)
Contributor Author

We can just compile all the forward methods. Ones that are not called will not be traced anyway.
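
For context, a minimal sketch of why compiling unused methods is essentially free: torch.compile wraps lazily, and tracing only happens on the first call.

```python
import torch

def forward_inference(x):
    return x * 2

def forward_exploration(x):
    return x + 1

# Wrapping does not trace anything yet; tracing happens on the first call.
compiled_inference = torch.compile(forward_inference)
compiled_exploration = torch.compile(forward_exploration)

# Only forward_inference is traced here; forward_exploration is never called
# and therefore never compiled.
compiled_inference(torch.ones(3))
```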

torch_dynamo_backend: str = "aot_eager" if sys.platform == "darwin" else "inductor"
torch_dynamo_mode: str = "reduce-overhead"
torch_dynamo_backend: str = (
"aot_eager" if sys.platform == "darwin" else "cudagraphs"
Contributor Author

Makes it so that weight updates actually take effect.

)
self.torch_compile_worker_dynamo_mode = "reduce-overhead"
self.torch_compile_worker_dynamo_mode = None
Contributor Author

None will make it so that, for any chosen backend, we use that backend's default mode.
cudagraphs does not have a "reduce-overhead" mode, so we need to choose None here.
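
As a hedged illustration (the backend choice is guarded for machines without a GPU), this is the usage pattern implied above: "reduce-overhead" is an inductor-specific mode, so with the cudagraphs backend the mode is simply left as None.

```python
import torch

def fwd(x):
    return torch.relu(x) * 2

device = "cuda" if torch.cuda.is_available() else "cpu"
# "cudagraphs" only makes sense on GPU; fall back to "aot_eager" on CPU.
backend = "cudagraphs" if device == "cuda" else "aot_eager"

# mode=None lets the chosen backend use its default behaviour.
compiled_fwd = torch.compile(fwd, backend=backend, mode=None)
compiled_fwd(torch.randn(8, device=device))
```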

settings.
"complete_update" promises the highest performance gains, but may work
in some settings. By compiling only forward_train, you may already get
some speedups and avoid issues that arise from compiling the entire update.
Contributor Author

In some cases, there are slight performance differences between compiling forward_train and compiling the complete update.
Until we have explored this and know whether we can eliminate one of the two options, we can use this switch to choose.
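
A hedged sketch of what that switch could look like on the learner side (function and argument names are illustrative, not the PR's exact code):

```python
import torch

def make_possibly_compiled_update(update_fn, rl_module, what_to_compile, **compile_kwargs):
    # Illustrative switch: either compile the complete update (forward_train +
    # loss + optimizer step), or only the module's forward_train.
    if what_to_compile == "complete_update":
        return torch.compile(update_fn, **compile_kwargs)
    elif what_to_compile == "forward_train":
        rl_module.forward_train = torch.compile(rl_module.forward_train, **compile_kwargs)
        return update_fn
    raise ValueError(f"Unknown what_to_compile value: {what_to_compile!r}")
```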

backend=torch_compile_cfg.torch_dynamo_backend,
mode=torch_compile_cfg.torch_dynamo_mode,
**torch_compile_cfg.kwargs,
)
Contributor Author

When compiling the update, we need to reset and recompile the whole thing every time we add/remove a module.
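
A minimal sketch of that behaviour (class and method names are hypothetical, not the PR's exact code):

```python
import torch
import torch._dynamo

class CompiledUpdateHolder:
    """Illustrative holder for a compiled update that recompiles on module changes."""

    def __init__(self, uncompiled_update, **compile_kwargs):
        self._uncompiled_update = uncompiled_update
        self._compile_kwargs = compile_kwargs
        self._recompile()

    def _recompile(self):
        self.possibly_compiled_update = torch.compile(
            self._uncompiled_update, **self._compile_kwargs
        )

    def add_module(self, module_id, module):
        # ...register the new module on the MultiAgentRLModule here...
        # Adding or removing a module changes the traced graph, so drop all
        # cached compiled code and recompile the whole update.
        torch._dynamo.reset()
        self._recompile()
```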

Contributor

add this comment to the code plz.

torch_compile_cfg: Optional["TorchCompileConfig"] = None

def validate(self):
Contributor

You need to expose these parameters on the top-level AlgorithmConfig. Right now, what_to_compile is not surfaced in the algorithm config.
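
For illustration, surfacing it could look roughly like this on the user side (the argument names mirror the config fields shown in the diffs above; treat them as an assumption rather than the merged API):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework(
        "torch",
        torch_compile_learner=True,
        # Assumed flag name mirroring the learner config field discussed above.
        torch_compile_learner_what_to_compile="forward_train",
    )
)
```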

Comment on lines 18 to 26
def _get_learner(learning_rate: float = 1e-3) -> Learner:
    env = gym.make("CartPole-v1")
    # adding learning rate as a configurable parameter to avoid hardcoding it
    # and information leakage across tests that rely on knowing the LR value
    # that is used in the learner.
    learner = get_learner("torch", env, learning_rate=learning_rate)
    learner.build()

    return learner
Contributor

modify get_learner in rllib.core.testing.utils and use it here?

spec = get_module_spec(
    framework="torch", env=env, is_multi_agent=is_multi_agent
)
learner = BCTorchLearner(
Contributor

why are you not using get_learner()?

Comment on lines +76 to +79
batch = MultiAgentBatch(
    {"another_module": reader.next(), "default_policy": reader.next()},
    0,
)
Contributor

do you really want to call reader.next() twice per iteration? is this intentional? you can obtain the batch once and use it in two places.
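
One possible rewrite under that suggestion (reusing the names from the snippet above; whether sharing the exact same batch object between both module ids is acceptable depends on the test's intent):

```python
from ray.rllib.policy.sample_batch import MultiAgentBatch

# Read once and reuse the same batch for both module ids.
module_batch = reader.next()
batch = MultiAgentBatch(
    {"another_module": module_batch, "default_policy": module_batch},
    0,
)
```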

Comment on lines 100 to 103
learner = BCTorchLearner(
    module_spec=spec,
    framework_hyperparameters=framework_hps,
)
Contributor

use get_learner

self._framework_hyperparameters.torch_compile_cfg
)
else:
assert isinstance(self._module, MultiAgentRLModule)
Contributor

please add a descriptive error upon failure. e.g. expected type blah got type blah
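
For example, a descriptive version of the assertion from the snippet above could read:

```python
assert isinstance(self._module, MultiAgentRLModule), (
    f"Expected self._module to be a MultiAgentRLModule, "
    f"but got {type(self._module).__name__}."
)
```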

else:
    assert isinstance(self._module, MultiAgentRLModule)
    for module in self._module._rl_modules.values():
        module.compile(
Contributor

you need to skip the module if it's not a TorchRLModule (e.g. it could be a RandomRLModule, neither torch nor TF)

Contributor

in other words, only compile those that are TorchRLModule
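
An illustrative version of the loop above that only compiles TorchRLModules (the import path is an assumption and may differ):

```python
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

for module in self._module._rl_modules.values():
    # Skip modules that are not torch modules, e.g. a RandomRLModule.
    if not isinstance(module, TorchRLModule):
        continue
    module.compile(
        backend=torch_compile_cfg.torch_dynamo_backend,
        mode=torch_compile_cfg.torch_dynamo_mode,
        **torch_compile_cfg.kwargs,
    )
```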

backend=torch_compile_cfg.torch_dynamo_backend,
mode=torch_compile_cfg.torch_dynamo_mode,
**torch_compile_cfg.kwargs,
)
Contributor

add this comment to the code plz.

Contributor

make _dynamo_is_available() a util under torch_utils.py?
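
A minimal sketch of what such a util in torch_utils.py could look like (hypothetical implementation):

```python
def _dynamo_is_available() -> bool:
    # torch._dynamo is only importable on sufficiently new torch versions.
    try:
        import torch._dynamo  # noqa: F401
        return True
    except ImportError:
        return False
```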

@@ -272,7 +272,9 @@ def compile_me(input_dict):

import torch._dynamo as dynamo

# This is a helper method of dynamo to analyze where breaks occur.
dynamo_explanation = dynamo.explain(compile_me, {"in": torch.Tensor([[1]])})
Contributor

same comments above apply here :)

@ArturNiederfahrenhorst
Contributor Author

@kouroshHakha I've also added a configuration enumerator instead of relying on two long strings "complete_update" and "forward_train".

Contributor
@kouroshHakha left a comment

@@ -282,6 +283,9 @@ def __init__(self, algo_class=None):
}
# Torch compile settings
self.torch_compile_learner = False
self.torch_compile_learner_what_to_compile = (
TorchCompileWhatToCompile.forward_train
Contributor

Enums should usually be all CAPS. e.g. FORWARD_TRAIN
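
In other words, something along these lines (member names are illustrative):

```python
import enum

class TorchCompileWhatToCompile(str, enum.Enum):
    # All-caps members, per the review comment above.
    FORWARD_TRAIN = "forward_train"
    COMPLETE_UPDATE = "complete_update"
```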

torch_compile_learner_what_to_compile: A string specifying what to
compile on the learner side if torch_compile_learner is True.
This can be one of the following:
- TorchCompileWhatToCompile.complete_update: Compile the
Contributor

Just say "see TorchCompileWhatToCompile for available options". This way you don't have to duplicate the docstring if things change later down the line.

forward_train method, the loss calculation and the optimizer step
together on the TorchLearner.
- TorchCompileWhatToCompile.forward_train: Compile only forward train.
Note:
Contributor

Use .. note:: directives and check the rendered documentation to see whether the formatting renders correctly.

Contributor

fair

@kouroshHakha merged commit 2a12cf5 into ray-project:master Jun 21, 2023
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023