[RLlib] Compile update logic on learner and use cudagraphs (#35759)
Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
ArturNiederfahrenhorst authored Jun 21, 2023
1 parent 827ab91 commit 2a12cf5
Showing 15 changed files with 436 additions and 108 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -2091,6 +2091,13 @@ py_test(
srcs = ["core/learner/tests/test_learner.py"]
)

py_test(
name = "test_torch_learner_compile",
tags = ["team:rllib", "core", "ray_data"],
size = "medium",
srcs = ["core/learner/torch/tests/test_torch_learner_compile.py"]
)

py_test(
name ="tests/test_algorithm_save_load_checkpoint_learner",
tags = ["team:rllib", "core"],
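
Assuming a working Bazel setup in a Ray checkout, the new target can typically be run with `bazel test //rllib:test_torch_learner_compile` from the repository root, or by pointing pytest at the test file directly; the exact invocation depends on the local build setup.
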
34 changes: 25 additions & 9 deletions rllib/algorithms/algorithm_config.py
@@ -25,6 +25,7 @@
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.core.learner.learner import TorchCompileWhatToCompile
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
@@ -279,15 +280,20 @@ def __init__(self, algo_class=None):
}
# Torch compile settings
self.torch_compile_learner = False
self.torch_compile_learner_what_to_compile = (
TorchCompileWhatToCompile.FORWARD_TRAIN
)
# AOT Eager is a dummy backend and will not result in speedups
self.torch_compile_learner_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
)
self.torch_compile_learner_dynamo_mode = "reduce-overhead"
self.torch_compile_learner_dynamo_mode = None
self.torch_compile_worker = False
# AOT Eager is a dummy backend and will not result in speedups
self.torch_compile_worker_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
"aot_eager" if sys.platform == "darwin" else "onnxrt"
)
self.torch_compile_worker_dynamo_mode = "reduce-overhead"
self.torch_compile_worker_dynamo_mode = None

# `self.environment()`
self.env = None
@@ -1197,6 +1203,7 @@ def framework(
tf_session_args: Optional[Dict[str, Any]] = NotProvided,
local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
torch_compile_learner: Optional[bool] = NotProvided,
torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker: Optional[bool] = NotProvided,
@@ -1223,8 +1230,12 @@ def framework(
local_tf_session_args: Override the following tf session args on the local
worker
torch_compile_learner: If True, forward_train methods on TorchRLModules
on the learner are compiled. If not specified, the default is to compile
forward train on the learner.
torch_compile_learner_what_to_compile: A TorchCompileWhatToCompile
mode specifying what to compile on the learner side if
torch_compile_learner is True. See TorchCompileWhatToCompile for
details and advice on its usage.
torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
the learner.
torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
@@ -1266,6 +1277,10 @@ def framework(
)
if torch_compile_learner_dynamo_mode is not NotProvided:
self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
if torch_compile_learner_what_to_compile is not NotProvided:
self.torch_compile_learner_what_to_compile = (
torch_compile_learner_what_to_compile
)
if torch_compile_worker is not NotProvided:
self.torch_compile_worker = torch_compile_worker
if torch_compile_worker_dynamo_backend is not NotProvided:
@@ -3056,7 +3071,6 @@ def get_torch_compile_learner_config(self):
)

return TorchCompileConfig(
compile_forward_train=self.torch_compile_learner,
torch_dynamo_backend=self.torch_compile_learner_dynamo_backend,
torch_dynamo_mode=self.torch_compile_learner_dynamo_mode,
)
@@ -3069,8 +3083,6 @@ def get_torch_compile_worker_config(self):
)

return TorchCompileConfig(
compile_forward_exploration=self.torch_compile_worker,
compile_forward_inference=self.torch_compile_worker,
torch_dynamo_backend=self.torch_compile_worker_dynamo_backend,
torch_dynamo_mode=self.torch_compile_worker_dynamo_mode,
)
@@ -3341,7 +3353,11 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi
)

if self.framework_str == "torch":
config.framework(torch_compile_cfg=self.get_torch_compile_learner_config())
config.framework(
torch_compile=self.torch_compile_learner,
torch_compile_cfg=self.get_torch_compile_learner_config(),
torch_compile_what_to_compile=self.torch_compile_learner_what_to_compile, # noqa: E501
)
elif self.framework_str == "tf2":
config.framework(eager_tracing=self.eager_tracing)

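
Taken together, the algorithm_config.py changes expose the learner-side compile switches through AlgorithmConfig.framework(). A minimal usage sketch, assuming PPO as the algorithm and using only the keyword arguments shown in this diff (not a verified, definitive invocation):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.learner.learner import TorchCompileWhatToCompile

config = (
    PPOConfig()
    .framework(
        framework="torch",
        # Compile the RLModule's forward_train on the learner (the default
        # and, per the docstrings below, the lower-risk of the two modes).
        torch_compile_learner=True,
        torch_compile_learner_what_to_compile=(
            TorchCompileWhatToCompile.FORWARD_TRAIN
        ),
        # "inductor" is the non-macOS default above; "aot_eager" is a dummy
        # backend and will not result in speedups.
        torch_compile_learner_dynamo_backend="inductor",
        torch_compile_learner_dynamo_mode=None,
    )
)

Passing TorchCompileWhatToCompile.COMPLETE_UPDATE instead would compile the whole update step (forward pass, loss computation, and optimizer step), which the learner.py docstrings below mark as experimental.
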
60 changes: 60 additions & 0 deletions rllib/core/learner/learner.py
@@ -3,6 +3,7 @@
import logging
import pathlib
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, field
from typing import (
Any,
@@ -82,6 +83,30 @@
LEARNER_RESULTS_CURR_LR_KEY = "curr_lr"


class TorchCompileWhatToCompile(str, Enum):
"""Enumerates schemes of what parts of the TorchLearner can be compiled.
This can be either the entire update step of the learner or only the forward
methods (and therein the forward_train method) of the RLModule.
.. note::
- Code compiled with torch.compile can become slow on graph breaks or even
raise errors on unsupported operations. Empirically, compiling
`forward_train` should introduce few graph breaks and raise no errors,
while still yielding a speedup comparable to compiling the complete update.
- Using `complete_update` is experimental and may result in errors.
"""

# Compile the entire update step of the learner.
# This includes the forward pass of the RLModule, the loss computation, and the
# optimizer step.
COMPLETE_UPDATE = "complete_update"
# Only compile the forward methods (and therein the forward_train method) of the
# RLModule.
FORWARD_TRAIN = "forward_train"


@dataclass
class FrameworkHyperparameters:
"""The framework specific hyper-parameters.
@@ -92,13 +117,47 @@ class FrameworkHyperparameters:
This is useful for speeding up the training loop. However, it is not
compatible with all tf operations. For example, tf.print is not supported
in tf.function.
torch_compile: Whether to use torch.compile() within the context of a given
learner.
what_to_compile: What to compile when using torch.compile(). Can be one of
[TorchCompileWhatToCompile.COMPLETE_UPDATE,
TorchCompileWhatToCompile.FORWARD_TRAIN].
If `complete_update`, the entire update step of the learner is compiled. This
includes the forward pass of the RLModule, the loss computation, and the
optimizer step.
If `forward_train`, only the forward methods (and therein the
forward_train method) of the RLModule are compiled.
The two options may yield different performance gains depending on the
setting. `complete_update` promises the highest gains but may not work in
some settings. By compiling only forward_train, you may already get some
speedups and avoid issues that arise from compiling the entire update.
torch_compile_cfg: The TorchCompileConfig to use for compiling the RLModule
in Torch.
"""

eager_tracing: bool = True
torch_compile: bool = False
what_to_compile: str = TorchCompileWhatToCompile.FORWARD_TRAIN
torch_compile_cfg: Optional["TorchCompileConfig"] = None

def validate(self):
if self.torch_compile:
if self.what_to_compile not in [
TorchCompileWhatToCompile.FORWARD_TRAIN,
TorchCompileWhatToCompile.COMPLETE_UPDATE,
]:
raise ValueError(
f"what_to_compile must be one of ["
f"TorchCompileWhatToCompile.forward_train, "
f"TorchCompileWhatToCompile.complete_update] but is"
f" {self.what_to_compile}"
)
if self.torch_compile_cfg is None:
raise ValueError(
"torch_compile_cfg must be set when torch_compile is True."
)


@dataclass
class LearnerHyperparameters:
@@ -314,6 +373,7 @@ def __init__(
self._framework_hyperparameters = (
framework_hyperparameters or FrameworkHyperparameters()
)
self._framework_hyperparameters.validate()

# whether self.build has already been called
self._is_built = False
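
The new enum and the validate() logic above can be exercised directly. A minimal sketch; the TorchCompileConfig value is left as a placeholder because its constructor lives outside this diff:

from ray.rllib.core.learner.learner import (
    FrameworkHyperparameters,
    TorchCompileWhatToCompile,
)

# Compile only the RLModule's forward_train (the default, lower-risk mode).
# A TorchCompileConfig is required whenever torch_compile is True; here it is
# a placeholder, e.g. the object returned by
# AlgorithmConfig.get_torch_compile_learner_config().
framework_hps = FrameworkHyperparameters(
    torch_compile=True,
    what_to_compile=TorchCompileWhatToCompile.FORWARD_TRAIN,
    torch_compile_cfg=...,  # placeholder for a TorchCompileConfig
)
framework_hps.validate()  # Would raise ValueError if torch_compile_cfg were None.
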
12 changes: 12 additions & 0 deletions rllib/core/learner/learner_group_config.py
@@ -56,7 +56,9 @@ def __init__(self, cls: Type[LearnerGroup] = None) -> None:

# `self.framework()`
self.eager_tracing = True
self.torch_compile = False
self.torch_compile_cfg = None
self.torch_compile_what_to_compile = None

def validate(self) -> None:

@@ -85,6 +87,8 @@ def build(self) -> LearnerGroup:
framework_hps = FrameworkHyperparameters(
eager_tracing=self.eager_tracing,
torch_compile_cfg=self.torch_compile_cfg,
torch_compile=self.torch_compile,
what_to_compile=self.torch_compile_what_to_compile,
)

learner_spec = LearnerSpec(
@@ -100,15 +104,23 @@
def framework(
self,
eager_tracing: Optional[bool] = NotProvided,
torch_compile: Optional[bool] = NotProvided,
torch_compile_cfg: Optional["TorchCompileConfig"] = NotProvided,
torch_compile_what_to_compile: Optional[str] = NotProvided,
) -> "LearnerGroupConfig":

if eager_tracing is not NotProvided:
self.eager_tracing = eager_tracing

if torch_compile is not NotProvided:
self.torch_compile = torch_compile

if torch_compile_cfg is not NotProvided:
self.torch_compile_cfg = torch_compile_cfg

if torch_compile_what_to_compile is not NotProvided:
self.torch_compile_what_to_compile = torch_compile_what_to_compile

return self

def module(
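
On the LearnerGroupConfig side, the same three values are accepted by framework() and forwarded into FrameworkHyperparameters by build(). A small sketch of that wiring, with the torch_compile_cfg value again left as a placeholder:

from ray.rllib.core.learner.learner import TorchCompileWhatToCompile
from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig

lg_config = LearnerGroupConfig().framework(
    torch_compile=True,
    torch_compile_cfg=...,  # placeholder for a TorchCompileConfig
    torch_compile_what_to_compile=TorchCompileWhatToCompile.FORWARD_TRAIN,
)
# LearnerGroupConfig.build() later maps these onto FrameworkHyperparameters(
#     torch_compile=..., what_to_compile=..., torch_compile_cfg=...)
# as shown in the hunk above.
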
18 changes: 15 additions & 3 deletions rllib/core/learner/tests/test_learner.py
@@ -7,6 +7,7 @@
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.testing.testing_learner import BaseTestingLearnerHyperparameters
from ray.rllib.core.testing.utils import get_learner, get_module_spec
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -140,9 +141,10 @@ def test_postprocess_gradients(self):
# Clip by global norm.
hps.grad_clip = 5.0
hps.grad_clip_by = "global_norm"
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
eager_tracing=True,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=hps,
)
@@ -175,8 +177,10 @@ def test_apply_gradients(self):
"""

for fw in framework_iterator(frameworks=("torch", "tf2")):
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=BaseTestingLearnerHyperparameters(learning_rate=0.0003),
)
@@ -212,8 +216,10 @@ def test_add_remove_module(self):
all variables the updated parameters follow the SGD update rule.
"""
for fw in framework_iterator(frameworks=("torch", "tf2")):
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(
framework=fw,
framework_hps=framework_hps,
env=self.ENV,
learner_hps=BaseTestingLearnerHyperparameters(learning_rate=0.0003),
)
@@ -258,11 +264,17 @@ def test_save_load_state(self):
"""Tests, whether a Learner's state is properly saved and restored."""
for fw in framework_iterator(frameworks=("torch", "tf2")):
# Get a Learner instance for the framework and env.
learner1 = get_learner(framework=fw, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner1 = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)
with tempfile.TemporaryDirectory() as tmpdir:
learner1.save_state(tmpdir)

learner2 = get_learner(framework=fw, env=self.ENV)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner2 = get_learner(
framework=fw, framework_hps=framework_hps, env=self.ENV
)
learner2.load_state(tmpdir)
self._check_learner_states(fw, learner1, learner2)

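
The test changes all follow the same pattern: framework-specific settings now reach get_learner() through a FrameworkHyperparameters object instead of a bare eager_tracing flag. A condensed sketch of that calling convention, assuming a CartPole environment and the torch framework (get_learner is the testing utility imported above):

import gymnasium as gym

from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.core.testing.utils import get_learner

env = gym.make("CartPole-v1")
framework_hps = FrameworkHyperparameters(eager_tracing=True)
learner = get_learner(framework="torch", framework_hps=framework_hps, env=env)
learner.build()
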
4 changes: 3 additions & 1 deletion rllib/core/learner/tests/test_learner_group.py
@@ -7,6 +7,7 @@

import ray
from ray.rllib.algorithms.ppo.tests.test_ppo_learner import FAKE_BATCH
from ray.rllib.core.learner.learner import FrameworkHyperparameters
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
SampleBatch,
@@ -63,7 +64,8 @@ def local_training_helper(self, fw, scaling_mode) -> None:
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(fw, env, scaling_config)
local_learner = get_learner(framework=fw, env=env)
framework_hps = FrameworkHyperparameters(eager_tracing=True)
local_learner = get_learner(framework=fw, framework_hps=framework_hps, env=env)
local_learner.build()

# make the state of the learner and the local learner_group identical