[RLlib] Compile update logic on learner and use cudagraphs #35759

Merged
Changes from 10 commits
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -2014,6 +2014,13 @@ py_test(
srcs = ["core/learner/tests/test_learner.py"]
)

py_test(
name = "test_torch_learner_compile",
tags = ["team:rllib", "core", "ray_data"],
size = "medium",
srcs = ["core/learner/torch/tests/test_torch_learner_compile.py"]
)

py_test(
name ="tests/test_algorithm_save_load_checkpoint_learner",
tags = ["team:rllib", "core"],
14 changes: 7 additions & 7 deletions rllib/algorithms/algorithm_config.py
@@ -286,12 +286,12 @@ def __init__(self, algo_class=None):
self.torch_compile_learner_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
)
self.torch_compile_learner_dynamo_mode = "reduce-overhead"
self.torch_compile_learner_dynamo_mode = None
self.torch_compile_worker = False
self.torch_compile_worker_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
"aot_eager" if sys.platform == "darwin" else "cudagraphs"
)
self.torch_compile_worker_dynamo_mode = "reduce-overhead"
self.torch_compile_worker_dynamo_mode = None
Contributor Author:
Setting this to None means that for any chosen backend, we use its default mode.
cudagraphs does not have a "reduce-overhead" mode, so we need to choose None here.

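The default backend/mode selection above can be sketched as a small standalone helper (hypothetical; `pick_dynamo_defaults` is not part of RLlib, and the platform reasoning is an assumption based on the diff):

```python
import sys


def pick_dynamo_defaults(for_worker: bool, platform: str = sys.platform):
    """Mirror the defaults in the diff above (hypothetical helper).

    On macOS ("darwin") the config falls back to aot_eager, presumably
    because inductor/cudagraphs are not usable there. mode=None lets
    torch.compile pick the chosen backend's default mode; cudagraphs has
    no "reduce-overhead" mode, so None is required for workers.
    """
    if platform == "darwin":
        backend = "aot_eager"
    else:
        backend = "cudagraphs" if for_worker else "inductor"
    # Returns (torch_dynamo_backend, torch_dynamo_mode).
    return backend, None
```

A usage note: these two values would then be passed through to `torch.compile(..., backend=..., mode=...)` by the compile config.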

# `self.environment()`
self.env = None
@@ -3035,7 +3035,6 @@ def get_torch_compile_learner_config(self):
)

return TorchCompileConfig(
compile_forward_train=self.torch_compile_learner,
torch_dynamo_backend=self.torch_compile_learner_dynamo_backend,
torch_dynamo_mode=self.torch_compile_learner_dynamo_mode,
)
@@ -3048,8 +3047,6 @@ def get_torch_compile_worker_config(self):
)

return TorchCompileConfig(
compile_forward_exploration=self.torch_compile_worker,
compile_forward_inference=self.torch_compile_worker,
torch_dynamo_backend=self.torch_compile_worker_dynamo_backend,
torch_dynamo_mode=self.torch_compile_worker_dynamo_mode,
)
@@ -3275,7 +3272,10 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi
)

if self.framework_str == "torch":
config.framework(torch_compile_cfg=self.get_torch_compile_learner_config())
config.framework(
torch_compile=self.torch_compile_learner,
torch_compile_cfg=self.get_torch_compile_learner_config(),
)
elif self.framework_str == "tf2":
config.framework(eager_tracing=self.eager_tracing)

23 changes: 23 additions & 0 deletions rllib/core/learner/learner.py
@@ -92,13 +92,35 @@ class FrameworkHyperparameters:
This is useful for speeding up the training loop. However, it is not
compatible with all tf operations. For example, tf.print is not supported
in tf.function.
torch_compile: Whether to use torch.compile() within the context of a given
learner.
what_to_compile: What to compile when using torch.compile(). Can be one of
["complete_update", "forward_train"].
If "complete_update", the update step of the learner will be compiled. This
includes the forward pass of the RLModule, the loss computation, and the
optimizer step.
If "forward_train", only the forward methods (and therein the
forward_train method) of the RLModule will be compiled.
Either of the two may lead to different performance gains in different
settings. "complete_update" promises the highest performance gains, but
may not work in some settings. By compiling only forward_train, you may
already get some speedups and avoid issues that arise from compiling the
entire update.
Contributor Author:
In some cases, there are slight performance differences when compiling forward train vs the complete update.
Until we have explored this and know if we can eliminate one or the other possibility, we can use this switch to choose.

torch_compile_cfg: The TorchCompileConfig to use for compiling the
RLModule in Torch.
"""

eager_tracing: bool = False
torch_compile: bool = False
what_to_compile: str = "complete_update"
torch_compile_cfg: Optional["TorchCompileConfig"] = None

def validate(self):
Contributor:

You need to expose these parameters to the top-level algorithm config. Right now, what_to_compile does not surface in AlgorithmConfig.

if self.what_to_compile not in ["complete_update", "forward_train"]:
raise ValueError(
"what_to_compile must be one of ['complete_update', 'forward_train']."
)


@dataclass
class LearnerHyperparameters:
@@ -314,6 +336,7 @@ def __init__(
self._framework_hyperparameters = (
framework_hyperparameters or FrameworkHyperparameters()
)
self._framework_hyperparameters.validate()

# whether self.build has already been called
self._is_built = False
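The new hyperparameters and their validation can be sketched in isolation (a simplified, hypothetical mirror of the dataclass in the diff above, not RLlib's actual class):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FrameworkHyperparametersSketch:
    """Simplified stand-in for RLlib's FrameworkHyperparameters (illustration only)."""

    eager_tracing: bool = False
    torch_compile: bool = False
    what_to_compile: str = "complete_update"
    torch_compile_cfg: Optional[object] = None

    def validate(self) -> None:
        # Same check as in the diff: only two compile scopes are allowed.
        if self.what_to_compile not in ["complete_update", "forward_train"]:
            raise ValueError(
                "what_to_compile must be one of "
                "['complete_update', 'forward_train']."
            )


# Compiling only the RLModule's forward methods:
hps = FrameworkHyperparametersSketch(
    torch_compile=True, what_to_compile="forward_train"
)
hps.validate()  # passes

# An unknown value fails fast, as it would at Learner construction time:
try:
    FrameworkHyperparametersSketch(what_to_compile="loss_only").validate()
except ValueError as e:
    print(e)
```

Because the Learner calls `validate()` in `__init__`, a misconfigured `what_to_compile` surfaces immediately rather than at update time.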
6 changes: 6 additions & 0 deletions rllib/core/learner/learner_group_config.py
@@ -56,6 +56,7 @@ def __init__(self, cls: Type[LearnerGroup] = None) -> None:

# `self.framework()`
self.eager_tracing = False
self.torch_compile = False
self.torch_compile_cfg = None

def validate(self) -> None:
@@ -85,6 +86,7 @@ def build(self) -> LearnerGroup:
framework_hps = FrameworkHyperparameters(
eager_tracing=self.eager_tracing,
torch_compile_cfg=self.torch_compile_cfg,
torch_compile=self.torch_compile,
)

learner_spec = LearnerSpec(
@@ -100,12 +102,16 @@
def framework(
self,
eager_tracing: Optional[bool] = NotProvided,
torch_compile: Optional[bool] = NotProvided,
torch_compile_cfg: Optional["TorchCompileConfig"] = NotProvided,
) -> "LearnerGroupConfig":

if eager_tracing is not NotProvided:
self.eager_tracing = eager_tracing

if torch_compile is not NotProvided:
self.torch_compile = torch_compile

if torch_compile_cfg is not NotProvided:
self.torch_compile_cfg = torch_compile_cfg

129 changes: 129 additions & 0 deletions rllib/core/learner/torch/tests/test_torch_learner_compile.py
Contributor:
This test will get skipped on CI right? Can we add that as a comment / TODO?

Contributor Author:
This should be clear from the @unittest.skipIf() above. I don't think we should add this as a comment because the comment can become out of date quickly.

Contributor:
fair

@@ -0,0 +1,129 @@
import itertools
import unittest

import gymnasium as gym

import ray
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.models.tests.test_base_models import _dynamo_is_available
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
from ray.rllib.core.testing.utils import get_learner
from ray.rllib.core.testing.utils import get_module_spec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.test_utils import get_cartpole_dataset_reader


def _get_learner(learning_rate: float = 1e-3) -> Learner:
env = gym.make("CartPole-v1")
# adding learning rate as a configurable parameter to avoid hardcoding it
# and information leakage across tests that rely on knowing the LR value
# that is used in the learner.
learner = get_learner("torch", env, learning_rate=learning_rate)
learner.build()

return learner
Contributor:
modify get_learner in rllib.core.testing.utils and use it here?



class TestLearner(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

@unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
def test_torch_compile(self):
"""Test if torch.compile() can be applied and used on the learner.

Also tests if we can update with the compiled update method without errors.
"""

env = gym.make("CartPole-v1")
is_multi_agents = [False, True]
what_to_compiles = ["complete_update", "forward_train"]

for is_multi_agent, what_to_compile in itertools.product(
is_multi_agents, what_to_compiles
):
framework_hps = FrameworkHyperparameters(
torch_compile=True,
torch_compile_cfg=TorchCompileConfig(),
what_to_compile=what_to_compile,
)
spec = get_module_spec(
framework="torch", env=env, is_multi_agent=is_multi_agent
)
learner = BCTorchLearner(
Contributor:
why are you not using get_learner()?

module_spec=spec,
framework_hyperparameters=framework_hps,
)
learner.build()

reader = get_cartpole_dataset_reader(batch_size=512)

for iter_i in range(10):
batch = reader.next()
learner.update(batch.as_multi_agent())

spec = get_module_spec(framework="torch", env=env, is_multi_agent=False)
learner.add_module(module_id="another_module", module_spec=spec)

for iter_i in range(10):
batch = MultiAgentBatch(
{"another_module": reader.next(), "default_policy": reader.next()},
0,
)
Comment on lines +70 to +73
Contributor:
do you really want to call reader.next() twice per iteration? is this intentional? you can obtain the batch once and use it in two places.

learner.update(batch)

learner.remove_module(module_id="another_module")

@unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
def test_torch_compile_no_breaks(self):
"""Tests if torch.compile() does encounter too many breaks.

torch.compile() should ideally not encounter any breaks when compiling the
update method of the learner. This method tests if we encounter only a given
number of breaks.
"""

env = gym.make("CartPole-v1")
framework_hps = FrameworkHyperparameters(
torch_compile=False,
Contributor:

Shouldn't torch_compile be True here? Can you explain?

torch_compile_cfg=TorchCompileConfig(),
)

spec = get_module_spec(framework="torch", env=env)
learner = BCTorchLearner(
module_spec=spec,
framework_hyperparameters=framework_hps,
)
Contributor:
use get_learner

learner.build()

import torch._dynamo as dynamo

reader = get_cartpole_dataset_reader(batch_size=512)

batch = reader.next().as_multi_agent()
batch = learner._convert_batch_type(batch)

# This is a helper method of dynamo to analyze where breaks occur.
dynamo_explanation = dynamo.explain(learner._update, batch)
Contributor:

It seems like newer versions output a dataclass. You either want to use that and pin the torch version in your tests (and skip or error if the version is inconsistent):

https://sourcegraph.com/github.com/pytorch/pytorch@e9674d146ce424d3ea44f8b2ffd9e9f92dfa15f7/-/blob/torch/_dynamo/backends/debugging.py

If you use the tuple return version, please don't hard-code indices. Instead, use named tuple assignment:

gm, graphs, op_count, ops_per_graph, break_reasons = dynamo.explain(learner._update, batch)

Contributor Author:

Just saw that the tuples change between versions, too. 2.0.1 adds a tuple entry over 2.0.0, so every new torch version is different here at the moment. I'll leave a comment that explains this so that we can fix it once our CI offers some version >= 2.0.0.

print(dynamo_explanation[5])

# There should be only one break reason - `return_value` - since inputs and
# outputs are not checked
break_reasons_list = dynamo_explanation[4]

# TODO(Artur): Attempt bringing breaks down to 1. (This may not be possible)
Contributor:
Is len always gonna be three?

Contributor Author:
Currently it is.
We'll want to fit this test to a couple of our RLModules in the future so that we don't introduce breaks by accident, because these cause silent regressions. For each module, we'll then have to record the number of graph breaks we expect over time.
For now, the test only includes the BC learner and should be stable with these three graph breaks.

self.assertEqual(len(break_reasons_list), 3)


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
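The version drift in `dynamo.explain()`'s return format discussed in the review thread could be absorbed with a small adapter; a sketch, where the `break_reasons` attribute name is an assumption based on the linked debugging.py, not a confirmed torch API:

```python
def extract_break_reasons(explanation):
    """Pull the break reasons out of a dynamo.explain() result.

    torch 2.0.x returns a plain tuple (break reasons at index 4 in 2.0.1),
    while newer versions return an ExplainOutput-style object; the
    attribute name `break_reasons` is an assumption here.
    """
    if hasattr(explanation, "break_reasons"):
        return explanation.break_reasons
    return explanation[4]
```

Using such an adapter in the test would avoid hard-coded tuple indices breaking on the next torch upgrade.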