Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Compile update logic on learner and use cudagraphs #35759

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.core.learner.learner import TorchCompileWhatToCompile
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
Expand Down Expand Up @@ -282,6 +283,9 @@ def __init__(self, algo_class=None):
}
# Torch compile settings
self.torch_compile_learner = False
self.torch_compile_learner_what_to_compile = (
TorchCompileWhatToCompile.forward_train
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enums should usually be all CAPS. e.g. FORWARD_TRAIN

)
self.torch_compile_learner_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
)
Expand Down Expand Up @@ -1228,6 +1232,7 @@ def framework(
tf_session_args: Optional[Dict[str, Any]] = NotProvided,
local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
torch_compile_learner: Optional[bool] = NotProvided,
torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker: Optional[bool] = NotProvided,
Expand All @@ -1254,8 +1259,22 @@ def framework(
local_tf_session_args: Override the following tf session args on the local
worker
torch_compile_learner: If True, forward_train methods on TorchRLModules
on the learner are compiled. If not specified, the default is to compile
forward train on the learner.
on the learner are compiled. If not specified, the default is to compile
forward train on the learner.
torch_compile_learner_what_to_compile: A string specifying what to
compile on the learner side if torch_compile_learner is True.
This can be one of the following:
- TorchCompileWhatToCompile.complete_update: Compile the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just say see TorchCompileWhatToCompile for available options. This way you don't have to duplicate docstring if things change later down the line.

forward_train method, the loss calculation and the optimizer step
together on the TorchLearner.
- TorchCompileWhatToCompile.forward_train: Compile only forward train.
Note:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ..Note:: directives and check the rendered documentation to see if the formating is correctly rendered.

- torch.compiled code can become slow on graph breaks or even raise
errors on unsupported operations. Empirically, compiling
`forward_train` should introduce few graph breaks, raise no
errors but result in a speedup comparable to compiling the
complete update.
- Using `complete_update` is experimental and may result in errors.
torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
the learner.
torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
Expand Down Expand Up @@ -1297,6 +1316,10 @@ def framework(
)
if torch_compile_learner_dynamo_mode is not NotProvided:
self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
if torch_compile_learner_what_to_compile is not NotProvided:
self.torch_compile_learner_what_to_compile = (
torch_compile_learner_what_to_compile
)
if torch_compile_worker is not NotProvided:
self.torch_compile_worker = torch_compile_worker
if torch_compile_worker_dynamo_backend is not NotProvided:
Expand Down Expand Up @@ -3336,6 +3359,7 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi
config.framework(
torch_compile=self.torch_compile_learner,
torch_compile_cfg=self.get_torch_compile_learner_config(),
torch_compile_what_to_compile=self.torch_compile_learner_what_to_compile, # noqa: E501
)
elif self.framework_str == "tf2":
config.framework(eager_tracing=self.eager_tracing)
Expand Down
43 changes: 34 additions & 9 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import pathlib
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, field
from typing import (
Any,
Expand Down Expand Up @@ -82,6 +83,18 @@
LEARNER_RESULTS_CURR_LR_KEY = "curr_lr"


class TorchCompileWhatToCompile(str, Enum):
    """Enumerates schemes of what parts of the TorchLearner can be compiled.

    Members subclass ``str``, so they compare equal to their plain string
    values (e.g. ``TorchCompileWhatToCompile.forward_train == "forward_train"``),
    which keeps string-based configs working.
    """

    # Compile the entire update step of the learner: the forward pass of the
    # RLModule, the loss computation, and the optimizer step.
    complete_update = "complete_update"
    # Only compile the forward methods (and therein the forward_train method)
    # of the RLModule.
    forward_train = "forward_train"
    # Conventional ALL_CAPS names. Assigning the same value again makes these
    # Enum *aliases* of the members above, so identity and equality are
    # preserved and existing lowercase references keep working.
    COMPLETE_UPDATE = "complete_update"
    FORWARD_TRAIN = "forward_train"


@dataclass
class FrameworkHyperparameters:
"""The framework specific hyper-parameters.
Expand All @@ -95,15 +108,16 @@ class FrameworkHyperparameters:
torch_compile: Whether to use torch.compile() within the context of a given
learner.
what_to_compile: What to compile when using torch.compile(). Can be one of
["complete_update", "forward_train"].
If "complete_update", the update step of the learner will be compiled. This
[TorchCompileWhatToCompile.complete_update,
TorchCompileWhatToCompile.forward_train].
If `complete_update`, the update step of the learner will be compiled. This
includes the forward pass of the RLModule, the loss computation, and the
optimizer step.
If "forward_train", only the forward methods (and therein the
If `forward_train`, only the forward methods (and therein the
forward_train method) of the RLModule will be compiled.
Either of the two may lead to different performance gains in different
settings.
"complete_update" promises the highest performance gains, but may work
`complete_update` promises the highest performance gains, but may not work
in some settings. By compiling only forward_train, you may already get
some speedups and avoid issues that arise from compiling the entire update.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some cases, there are slight performance differences when compiling forward train vs the complete update.
Until we have explored this and know if we can eliminate one or the other possibility, we can use this switch to choose.

torch_compile_config: The TorchCompileConfig to use for compiling the RL
Expand All @@ -112,14 +126,25 @@ class FrameworkHyperparameters:

eager_tracing: bool = False
torch_compile: bool = False
what_to_compile: str = "complete_update"
what_to_compile: str = TorchCompileWhatToCompile.forward_train
torch_compile_cfg: Optional["TorchCompileConfig"] = None

def validate(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to expose these parameters to the top algorithm_config. Right now what_to_compile is not surfacing up in algorithm config.

if self.what_to_compile not in ["complete_update", "forward_train"]:
raise ValueError(
"what_to_compile must be one of ['complete_update', 'forward_train']."
)
if self.torch_compile:
if self.what_to_compile not in [
TorchCompileWhatToCompile.forward_train,
TorchCompileWhatToCompile.complete_update,
]:
raise ValueError(
f"what_to_compile must be one of ["
f"TorchCompileWhatToCompile.forward_train, "
f"TorchCompileWhatToCompile.complete_update] but is"
f" {self.what_to_compile}"
)
if self.torch_compile_cfg is None:
raise ValueError(
"torch_compile_cfg must be set when torch_compile is True."
)


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions rllib/core/learner/learner_group_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, cls: Type[LearnerGroup] = None) -> None:
self.eager_tracing = False
self.torch_compile = False
self.torch_compile_cfg = None
self.torch_compile_what_to_compile = None

def validate(self) -> None:

Expand Down Expand Up @@ -87,6 +88,7 @@ def build(self) -> LearnerGroup:
eager_tracing=self.eager_tracing,
torch_compile_cfg=self.torch_compile_cfg,
torch_compile=self.torch_compile,
what_to_compile=self.torch_compile_what_to_compile,
)

learner_spec = LearnerSpec(
Expand All @@ -104,6 +106,7 @@ def framework(
eager_tracing: Optional[bool] = NotProvided,
torch_compile: Optional[bool] = NotProvided,
torch_compile_cfg: Optional["TorchCompileConfig"] = NotProvided,
torch_compile_what_to_compile: Optional[str] = NotProvided,
) -> "LearnerGroupConfig":

if eager_tracing is not NotProvided:
Expand All @@ -115,6 +118,9 @@ def framework(
if torch_compile_cfg is not NotProvided:
self.torch_compile_cfg = torch_compile_cfg

if torch_compile_what_to_compile is not NotProvided:
self.torch_compile_what_to_compile = torch_compile_what_to_compile

return self

def module(
Expand Down
34 changes: 27 additions & 7 deletions rllib/core/learner/tests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.testing.testing_learner import BaseTestingLearnerHyperparameters
from ray.rllib.core.testing.utils import get_learner, get_module_spec
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.framework import try_import_tf, try_import_torch
Expand Down Expand Up @@ -36,7 +37,10 @@ def tearDown(cls) -> None:
def test_end_to_end_update(self):

for fw in framework_iterator(frameworks=("torch", "tf2")):
learner = get_learner(framework=fw, eager_tracing=True, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)
reader = get_cartpole_dataset_reader(batch_size=512)

min_loss = float("inf")
Expand All @@ -60,7 +64,10 @@ def test_compute_gradients(self):
the weights is all ones.
"""
for fw in framework_iterator(frameworks=("torch", "tf2")):
learner = get_learner(framework=fw, eager_tracing=True, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)

params = learner.get_parameters(learner.module[DEFAULT_POLICY_ID])

Expand Down Expand Up @@ -92,9 +99,10 @@ def test_postprocess_gradients(self):
grad_clip_by="value",
)

framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
eager_tracing=True,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=hps,
)
Expand All @@ -117,9 +125,10 @@ def test_postprocess_gradients(self):
# Clip by norm.
hps.grad_clip = 1.0
hps.grad_clip_by = "norm"
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
eager_tracing=True,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=hps,
)
Expand All @@ -142,9 +151,10 @@ def test_postprocess_gradients(self):
# Clip by global norm.
hps.grad_clip = 5.0
hps.grad_clip_by = "global_norm"
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
eager_tracing=True,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=hps,
)
Expand Down Expand Up @@ -177,8 +187,10 @@ def test_apply_gradients(self):
"""

for fw in framework_iterator(frameworks=("torch", "tf2")):
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=BaseTestingLearnerHyperparameters(learning_rate=0.0003),
)
Expand Down Expand Up @@ -214,8 +226,10 @@ def test_add_remove_module(self):
all variables the updated parameters follow the SGD update rule.
"""
for fw in framework_iterator(frameworks=("torch", "tf2")):
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=BaseTestingLearnerHyperparameters(learning_rate=0.0003),
)
Expand Down Expand Up @@ -260,11 +274,17 @@ def test_save_load_state(self):
"""Tests, whether a Learner's state is properly saved and restored."""
for fw in framework_iterator(frameworks=("torch", "tf2")):
# Get a Learner instance for the framework and env.
learner1 = get_learner(framework=fw, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner1 = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)
with tempfile.TemporaryDirectory() as tmpdir:
learner1.save_state(tmpdir)

learner2 = get_learner(framework=fw, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner2 = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)
learner2.load_state(tmpdir)
self._check_learner_states(fw, learner1, learner2)

Expand Down
4 changes: 3 additions & 1 deletion rllib/core/learner/tests/test_learner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ray
from ray.rllib.algorithms.ppo.tests.test_ppo_learner import FAKE_BATCH
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
SampleBatch,
Expand Down Expand Up @@ -63,7 +64,8 @@ def local_training_helper(self, fw, scaling_mode) -> None:
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(fw, env, scaling_config, eager_tracing=True)
local_learner = get_learner(framework=fw, env=env)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
local_learner = get_learner(framework=fw, framework_hps=framework_hps, env=env)
local_learner.build()

# make the state of the learner and the local learner_group identical
Expand Down
Loading