[RLlib] Compile update logic on learner and use cudagraphs #35759
@@ -28,6 +28,7 @@
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.core.learner.learner import TorchCompileWhatToCompile
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
@@ -282,6 +283,9 @@ def __init__(self, algo_class=None):
        }
        # Torch compile settings
        self.torch_compile_learner = False
        self.torch_compile_learner_what_to_compile = (
            TorchCompileWhatToCompile.forward_train
        )
        self.torch_compile_learner_dynamo_backend = (
            "aot_eager" if sys.platform == "darwin" else "inductor"
        )
@@ -1228,6 +1232,7 @@ def framework(
        tf_session_args: Optional[Dict[str, Any]] = NotProvided,
        local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
        torch_compile_learner: Optional[bool] = NotProvided,
        torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
        torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
        torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
        torch_compile_worker: Optional[bool] = NotProvided,
@@ -1254,8 +1259,22 @@ def framework(
            local_tf_session_args: Override the following tf session args on the local
                worker
            torch_compile_learner: If True, forward_train methods on TorchRLModules
                on the learner are compiled. If not specified, the default is to compile
                forward train on the learner.
            torch_compile_learner_what_to_compile: A string specifying what to
                compile on the learner side if torch_compile_learner is True.
                This can be one of the following:
                - TorchCompileWhatToCompile.complete_update: Compile the
                  forward_train method, the loss calculation, and the optimizer step
                  together on the TorchLearner.
                - TorchCompileWhatToCompile.forward_train: Compile only forward train.
                Note:
                - torch.compiled code can become slow on graph breaks or even raise
                  errors on unsupported operations. Empirically, compiling
                  `forward_train` should introduce few graph breaks, raise no
                  errors, and result in a speedup comparable to compiling the
                  complete update.
                - Using `complete_update` is experimental and may result in errors.
            torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
                the learner.
            torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the

Review comment: Just say "see TorchCompileWhatToCompile for available options." That way you don't have to duplicate the docstring if things change later down the line.

Review comment: Use `.. note::` directives and check the rendered documentation to see if the formatting renders correctly.
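A usage sketch (not part of this diff) of how the new argument would be set through the existing framework() call. PPOConfig is only an example AlgorithmConfig subclass here, and the backend value is an illustrative choice:

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.core.learner.learner import TorchCompileWhatToCompile

    config = (
        PPOConfig()
        .framework(
            "torch",
            torch_compile_learner=True,
            # Compile only the RLModule's forward_train (the default added here).
            torch_compile_learner_what_to_compile=TorchCompileWhatToCompile.forward_train,
            torch_compile_learner_dynamo_backend="inductor",
        )
    )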
@@ -1297,6 +1316,10 @@ def framework(
        )
        if torch_compile_learner_dynamo_mode is not NotProvided:
            self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
        if torch_compile_learner_what_to_compile is not NotProvided:
            self.torch_compile_learner_what_to_compile = (
                torch_compile_learner_what_to_compile
            )
        if torch_compile_worker is not NotProvided:
            self.torch_compile_worker = torch_compile_worker
        if torch_compile_worker_dynamo_backend is not NotProvided:
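For readers unfamiliar with the pattern used by these setters, here is a minimal, self-contained sketch of the NotProvided sentinel idiom; the class and attribute names are illustrative stand-ins, not RLlib's actual implementation:

    class _NotProvidedSentinel:
        """Stand-in for RLlib's NotProvided sentinel."""

    NOT_PROVIDED = _NotProvidedSentinel()

    class ExampleConfig:
        def __init__(self):
            self.what_to_compile = "forward_train"

        def framework(self, what_to_compile=NOT_PROVIDED):
            # Only overwrite the stored default if the caller actually passed a
            # value; an explicit None or False still counts as "provided".
            if what_to_compile is not NOT_PROVIDED:
                self.what_to_compile = what_to_compile
            return self

    ExampleConfig().framework()                   # keeps the default
    ExampleConfig().framework("complete_update")  # overrides it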
@@ -3336,6 +3359,7 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
            config.framework(
                torch_compile=self.torch_compile_learner,
                torch_compile_cfg=self.get_torch_compile_learner_config(),
                torch_compile_what_to_compile=self.torch_compile_learner_what_to_compile,  # noqa: E501
            )
        elif self.framework_str == "tf2":
            config.framework(eager_tracing=self.eager_tracing)
rllib/core/learner/learner.py
@@ -3,6 +3,7 @@
import logging
import pathlib
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, field
from typing import (
    Any,
@@ -82,6 +83,18 @@
LEARNER_RESULTS_CURR_LR_KEY = "curr_lr"


class TorchCompileWhatToCompile(str, Enum):
    """Enumerates schemes of what parts of the TorchLearner can be compiled."""

    # Compile the entire update step of the learner.
    # This includes the forward pass of the RLModule, the loss computation, and the
    # optimizer step.
    complete_update = "complete_update"
    # Only compile the forward methods (and therein the forward_train method) of the
    # RLModule.
    forward_train = "forward_train"


@dataclass
class FrameworkHyperparameters:
    """The framework specific hyper-parameters.
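Since the enum subclasses str, its members behave like the plain strings they replace. A small stand-alone illustration (reproducing only the enum, nothing else from the module):

    from enum import Enum

    class TorchCompileWhatToCompile(str, Enum):
        complete_update = "complete_update"
        forward_train = "forward_train"

    # Members compare equal to their plain-string values, so call sites that
    # still pass "forward_train" as a string keep working.
    assert TorchCompileWhatToCompile.forward_train == "forward_train"
    # A plain string can also be converted back into the enum member.
    assert TorchCompileWhatToCompile("complete_update") is TorchCompileWhatToCompile.complete_update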
@@ -95,15 +108,16 @@ class FrameworkHyperparameters:
        torch_compile: Whether to use torch.compile() within the context of a given
            learner.
        what_to_compile: What to compile when using torch.compile(). Can be one of
            [TorchCompileWhatToCompile.complete_update,
            TorchCompileWhatToCompile.forward_train].
            If `complete_update`, the update step of the learner will be compiled. This
            includes the forward pass of the RLModule, the loss computation, and the
            optimizer step.
            If `forward_train`, only the forward methods (and therein the
            forward_train method) of the RLModule will be compiled.
            Either of the two may lead to different performance gains in different
            settings.
            `complete_update` promises the highest performance gains, but may not work
            in some settings. By compiling only forward_train, you may already get
            some speedups and avoid issues that arise from compiling the entire update.
        torch_compile_cfg: The TorchCompileConfig to use for compiling the RL

Review comment: In some cases, there are slight performance differences when compiling forward train vs the complete update.
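To make the two scopes concrete, here is a conceptual torch.compile sketch. It is not RLlib's TorchLearner code; the model, optimizer, and loss are arbitrary placeholders:

    import torch

    model = torch.nn.Linear(8, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.01)

    # "forward_train"-style: compile only the module's forward pass.
    compiled_forward = torch.compile(model)

    # "complete_update"-style: compile forward pass, loss, and optimizer step as
    # one function. Dynamo sees a larger graph, but there are more chances for
    # graph breaks or unsupported ops.
    def _update(batch, labels):
        loss = torch.nn.functional.mse_loss(model(batch), labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
        return loss

    compiled_update = torch.compile(_update)

    x, y = torch.randn(4, 8), torch.randn(4, 2)
    compiled_forward(x)     # only the forward graph is traced
    compiled_update(x, y)   # the whole update is traced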
@@ -112,14 +126,25 @@ class FrameworkHyperparameters:

    eager_tracing: bool = False
    torch_compile: bool = False
    what_to_compile: str = TorchCompileWhatToCompile.forward_train
    torch_compile_cfg: Optional["TorchCompileConfig"] = None

    def validate(self):
        if self.torch_compile:
            if self.what_to_compile not in [
                TorchCompileWhatToCompile.forward_train,
                TorchCompileWhatToCompile.complete_update,
            ]:
                raise ValueError(
                    f"what_to_compile must be one of ["
                    f"TorchCompileWhatToCompile.forward_train, "
                    f"TorchCompileWhatToCompile.complete_update] but is"
                    f" {self.what_to_compile}"
                )
            if self.torch_compile_cfg is None:
                raise ValueError(
                    "torch_compile_cfg must be set when torch_compile is True."
                )


@dataclass

Review comment: You need to expose these parameters in the top-level AlgorithmConfig. Right now, what_to_compile is not surfacing up in the algorithm config.
Review comment: Enum members should usually be all caps, e.g. FORWARD_TRAIN.
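A quick sketch of the new validation path, assuming FrameworkHyperparameters and TorchCompileWhatToCompile remain importable from ray.rllib.core.learner.learner as in this diff:

    from ray.rllib.core.learner.learner import (
        FrameworkHyperparameters,
        TorchCompileWhatToCompile,
    )

    # torch_compile=True without a TorchCompileConfig should now fail fast.
    hps = FrameworkHyperparameters(
        torch_compile=True,
        what_to_compile=TorchCompileWhatToCompile.complete_update,
    )
    try:
        hps.validate()
    except ValueError as err:
        print(err)  # torch_compile_cfg must be set when torch_compile is True.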