[RLlib] Compile update logic on learner and use cudagraphs #35759
@@ -92,13 +92,35 @@ class FrameworkHyperparameters:
            This is useful for speeding up the training loop. However, it is not
            compatible with all tf operations. For example, tf.print is not supported
            in tf.function.
        torch_compile: Whether to use torch.compile() within the context of a given
            learner.
        what_to_compile: What to compile when using torch.compile(). Can be one of
            ["complete_update", "forward_train"].
            If "complete_update", the update step of the learner will be compiled. This
            includes the forward pass of the RLModule, the loss computation, and the
            optimizer step.
            If "forward_train", only the forward methods (and therein the
            forward_train method) of the RLModule will be compiled.
            Either of the two may lead to different performance gains in different
            settings.
            "complete_update" promises the highest performance gains, but may not work
            in some settings. By compiling only forward_train, you may already get
            some speedups and avoid issues that arise from compiling the entire update.
        torch_compile_cfg: The TorchCompileConfig to use for compiling the RL
            Module in Torch.
        """

Review comment: In some cases, there are slight performance differences when compiling forward_train vs the complete update.
    eager_tracing: bool = False
    torch_compile: bool = False
    what_to_compile: str = "complete_update"
    torch_compile_cfg: Optional["TorchCompileConfig"] = None

    def validate(self):
        if self.what_to_compile not in ["complete_update", "forward_train"]:
            raise ValueError(
                "what_to_compile must be one of ['complete_update', 'forward_train']."
            )

Review comment (on validate): You need to expose these parameters to the top-level algorithm config. Right now, what_to_compile is not surfacing in the algorithm config.
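The attribute defaults and the validate() check above can be exercised standalone. The sketch below is a minimal stand-in for RLlib's FrameworkHyperparameters (the class name is altered to mark it as a mock, and torch_compile_cfg is loosened to a plain object so the example does not depend on RLlib or torch):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FrameworkHyperparametersSketch:
    # Minimal mock of RLlib's FrameworkHyperparameters; only the fields
    # relevant to torch.compile() are reproduced here.
    torch_compile: bool = False
    what_to_compile: str = "complete_update"
    torch_compile_cfg: Optional[object] = None

    def validate(self) -> None:
        # Mirrors the validate() in the diff: only the two documented
        # compilation scopes are accepted.
        if self.what_to_compile not in ["complete_update", "forward_train"]:
            raise ValueError(
                "what_to_compile must be one of "
                "['complete_update', 'forward_train']."
            )


# Valid settings pass silently; an unknown scope raises.
FrameworkHyperparametersSketch(what_to_compile="forward_train").validate()
try:
    FrameworkHyperparametersSketch(what_to_compile="everything").validate()
except ValueError as e:
    print("rejected:", e)
```

Calling validate() eagerly at construction time (as the Learner's `__init__` does below) means a misspelled scope fails fast rather than surfacing later as an unexpectedly uncompiled update.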
@dataclass
class LearnerHyperparameters:

@@ -314,6 +336,7 @@ def __init__(
        self._framework_hyperparameters = (
            framework_hyperparameters or FrameworkHyperparameters()
        )
        self._framework_hyperparameters.validate()

        # whether self.build has already been called
        self._is_built = False
Review comment: This test will get skipped on CI, right? Can we add that as a comment / TODO?
Reply: This should be clear from the @unittest.skipIf() above. I don't think we should add this as a comment because the comment can become out of date quickly.
Reply: Fair.
@@ -0,0 +1,129 @@
import itertools
import unittest

import gymnasium as gym

import ray
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.models.tests.test_base_models import _dynamo_is_available
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
from ray.rllib.core.testing.utils import get_learner
from ray.rllib.core.testing.utils import get_module_spec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.test_utils import get_cartpole_dataset_reader


def _get_learner(learning_rate: float = 1e-3) -> Learner:
    env = gym.make("CartPole-v1")
    # The learning rate is a configurable parameter to avoid hardcoding it and
    # information leakage across tests that rely on knowing the LR value that
    # is used in the learner.
    learner = get_learner("torch", env, learning_rate=learning_rate)
    learner.build()

    return learner
Review comment: modify
class TestLearner(unittest.TestCase):
    @classmethod
    def setUp(cls) -> None:
        ray.init()

    @classmethod
    def tearDown(cls) -> None:
        ray.shutdown()

    @unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
    def test_torch_compile(self):
        """Test if torch.compile() can be applied and used on the learner.

        Also tests if we can update with the compiled update method without errors.
        """
        env = gym.make("CartPole-v1")
        is_multi_agents = [False, True]
        what_to_compiles = ["complete_update", "forward_train"]

        for is_multi_agent, what_to_compile in itertools.product(
            is_multi_agents, what_to_compiles
        ):
            framework_hps = FrameworkHyperparameters(
                torch_compile=True,
                torch_compile_cfg=TorchCompileConfig(),
                what_to_compile=what_to_compile,
            )
            spec = get_module_spec(
                framework="torch", env=env, is_multi_agent=is_multi_agent
            )
            learner = BCTorchLearner(
                module_spec=spec,
                framework_hyperparameters=framework_hps,
            )

Review comment (on BCTorchLearner): why are you not using get_learner()?

            learner.build()
||
reader = get_cartpole_dataset_reader(batch_size=512) | ||
|
||
for iter_i in range(10): | ||
batch = reader.next() | ||
learner.update(batch.as_multi_agent()) | ||
|
||
spec = get_module_spec(framework="torch", env=env, is_multi_agent=False) | ||
learner.add_module(module_id="another_module", module_spec=spec) | ||
|
||
for iter_i in range(10): | ||
batch = MultiAgentBatch( | ||
{"another_module": reader.next(), "default_policy": reader.next()}, | ||
0, | ||
) | ||
Comment on lines
+70
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you really want to call |
||
learner.update(batch) | ||
|
||
learner.remove_module(module_id="another_module") | ||
|
    @unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
    def test_torch_compile_no_breaks(self):
        """Tests that torch.compile() does not encounter too many breaks.

        torch.compile() should ideally not encounter any breaks when compiling the
        update method of the learner. This method tests if we encounter only a given
        number of breaks.
        """
        env = gym.make("CartPole-v1")
        framework_hps = FrameworkHyperparameters(
            torch_compile=False,
            torch_compile_cfg=TorchCompileConfig(),
        )

Review comment: torch_compile shouldn't be true here? Can you explain?

        spec = get_module_spec(framework="torch", env=env)
        learner = BCTorchLearner(
            module_spec=spec,
            framework_hyperparameters=framework_hps,
        )

Review comment: Use get_learner.

        learner.build()
        import torch._dynamo as dynamo

        reader = get_cartpole_dataset_reader(batch_size=512)

        batch = reader.next().as_multi_agent()
        batch = learner._convert_batch_type(batch)

        # This is a helper method of dynamo to analyze where breaks occur.
        dynamo_explanation = dynamo.explain(learner._update, batch)

Review comment: It seems like on the newer versions, they output a dataclass. You either want to use that and make sure we pin the torch version in your tests (and skip or error if the version is not consistent). If you use the tuple return version, please don't hard-code indices; instead use the named-tuple assignment.
Reply: Just saw that the tuples change between versions, too.

        print(dynamo_explanation[5])

        # There should be only one break reason - `return_value` - since inputs
        # and outputs are not checked.
        break_reasons_list = dynamo_explanation[4]

        # TODO(Artur): Attempt bringing breaks down to 1. (This may not be possible.)

Review comment: Is len always gonna be three?
Reply: Currently it is.

        self.assertEqual(len(break_reasons_list), 3)
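The reviewer's concern about `dynamo.explain()` return types changing between torch releases can be handled defensively. The helper below is a hypothetical sketch (not part of the PR): it prefers named-attribute access, as used by newer versions that return an ExplainOutput-style dataclass, and falls back to the positional index the test above assumes. The stand-in objects only mimic the two return shapes, so the sketch runs without torch installed:

```python
from types import SimpleNamespace


def get_break_reasons(explanation):
    """Extract break reasons from a dynamo.explain()-style result.

    Newer torch versions return an object with named attributes; older
    versions return a plain tuple whose layout has shifted between
    releases, so positional access is only a last resort.
    """
    if hasattr(explanation, "break_reasons"):
        return explanation.break_reasons
    # Assumed positional layout for older tuple-returning versions
    # (index 4 matches the hard-coded index in the test above).
    return explanation[4]


# Stand-ins for the two return styles (not real dynamo output):
new_style = SimpleNamespace(break_reasons=["return_value"])
old_style = (None, None, None, None, ["return_value"], "report text")

print(get_break_reasons(new_style))  # attribute path
print(get_break_reasons(old_style))  # positional fallback
```

Pinning the torch version in CI, as the reviewer suggests, remains the more robust option; this fallback only papers over the layout drift.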
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
Review comment: None will make it so that for any chosen backend, we use the default mode. cudagraphs does not have a "reduce-overhead" mode, so we need to choose None here.
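The rule in the comment above can be written down as a small selection helper. This is a sketch with assumed names (it is not RLlib code, and the set of modes a backend supports is taken from the comment, not checked against torch): passing mode=None lets torch.compile() fall back to the backend's default, which is required for the cudagraphs backend since it has no "reduce-overhead" mode.

```python
from typing import Optional


def choose_compile_mode(backend: str) -> Optional[str]:
    """Pick a torch.compile() mode argument for a given backend.

    "reduce-overhead" is only meaningful for backends that support it.
    The cudagraphs backend does not, so we return None, which makes
    torch.compile() use that backend's default mode.
    """
    if backend == "cudagraphs":
        return None
    return "reduce-overhead"


# The chosen mode would then be forwarded as, e.g.:
#   torch.compile(update_fn, backend=backend, mode=choose_compile_mode(backend))
print(choose_compile_mode("cudagraphs"))  # None -> backend default mode
print(choose_compile_mode("inductor"))
```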