Wrapper module around TRT + pytorch subgraphs #3270

Merged on Dec 18, 2024 (17 commits)
20 changes: 10 additions & 10 deletions core/runtime/execute_engine.cpp
@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
void setup_input_tensors(
std::vector<at::Tensor> inputs,
c10::intrusive_ptr<TRTEngine> compiled_engine,
+ bool cudagraphs_enabled,
bool need_cudagraphs_record) {
// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

- if (CUDAGRAPHS_MODE) {
+ if (cudagraphs_enabled) {
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
}
@@ -147,7 +148,7 @@ void setup_input_tensors(
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

- if (CUDAGRAPHS_MODE) {
+ if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
@@ -201,17 +202,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
compiled_engine->cudagraph.enable_debug_mode();
}

+ bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
bool shape_changed = _validate_shapes(inputs, compiled_engine);

// Whether cudagraphs needs to record the graph on this pass
auto result = compiled_engine->runtime_states.set_runtime_states(
- CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
+ cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);

bool need_cudagraphs_record = std::get<0>(result);
bool can_use_pre_allocated_outputs = std::get<1>(result);

- if (!CUDAGRAPHS_MODE || shape_changed) {
+ if (!cudagraphs_enabled || shape_changed) {
compiled_engine->cudagraph.reset();
}

@@ -273,8 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

- setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);
-
+ setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
@@ -306,7 +306,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

- if (CUDAGRAPHS_MODE) {
+ if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -346,7 +346,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);

- if (!CUDAGRAPHS_MODE) {
+ if (!cudagraphs_enabled) {
// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
} else {
@@ -377,7 +377,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

- if (CUDAGRAPHS_MODE) {
+ if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
outputs[o].copy_(compiled_engine->output_buffers[o], false);
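At the Python level, the gating above reduces to the mode set through torch_tensorrt.runtime: only the subgraph mode triggers the per-engine record/replay path in execute_engine, while WHOLE_GRAPH_CUDAGRAPHS corresponds to the wrapper-module path this PR adds. A minimal sketch under those assumptions (set_cudagraphs_mode keeps its boolean Python signature, as the example file below shows):

import torch_tensorrt

# False -> STANDARD: enqueueV3 runs directly on the caller's buffers
torch_tensorrt.runtime.set_cudagraphs_mode(False)

# True -> SUBGRAPH_CUDAGRAPHS: each TRT engine records and replays its own CUDA Graph,
# re-recording whenever _validate_shapes detects a new input shape
torch_tensorrt.runtime.set_cudagraphs_mode(True)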
6 changes: 4 additions & 2 deletions core/runtime/register_jit_hooks.cpp
@@ -112,8 +112,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
6 changes: 3 additions & 3 deletions core/runtime/runtime.cpp
@@ -8,7 +8,7 @@ namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;
- bool CUDAGRAPHS_MODE = false;
+ CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
}

- bool get_cudagraphs_mode() {
+ CudaGraphsMode get_cudagraphs_mode() {
return CUDAGRAPHS_MODE;
}

- void set_cudagraphs_mode(bool cudagraphs_mode) {
+ void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
CUDAGRAPHS_MODE = cudagraphs_mode;
}

13 changes: 10 additions & 3 deletions core/runtime/runtime.h
@@ -18,7 +18,14 @@ namespace runtime {
using EngineID = int64_t;
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
- extern bool CUDAGRAPHS_MODE;

+ typedef enum {
+   STANDARD = 0,
+   SUBGRAPH_CUDAGRAPHS,
+   WHOLE_GRAPH_CUDAGRAPHS,
+ } CudaGraphsMode;
+
+ extern CudaGraphsMode CUDAGRAPHS_MODE;

typedef enum {
ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();

void set_multi_device_safe_mode(bool multi_device_safe_mode);

- bool get_cudagraphs_mode();
+ CudaGraphsMode get_cudagraphs_mode();

- void set_cudagraphs_mode(bool cudagraphs_mode);
+ void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
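Since the JIT hooks now exchange the mode as an int64_t, the Python-visible integers map one-to-one onto this CudaGraphsMode enum. A hedged sketch of driving the raw registered ops directly (the local constant names are illustrative aliases; normal callers should go through torch_tensorrt.runtime instead):

import torch
import torch_tensorrt  # importing registers the torch.ops.tensorrt namespace

# Local aliases mirroring CudaGraphsMode in core/runtime/runtime.h
STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS = 0, 1, 2

prior = torch.ops.tensorrt.get_cudagraphs_mode()  # now returns an int64_t, not a bool
torch.ops.tensorrt.set_cudagraphs_mode(SUBGRAPH_CUDAGRAPHS)
# ... run inference ...
torch.ops.tensorrt.set_cudagraphs_mode(prior)  # restore the prior mode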
2 changes: 1 addition & 1 deletion docsrc/user_guide/runtime.rst
@@ -86,7 +86,7 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume
torch_tensorrt.runtime.set_cudagraphs_mode(False)

# Enables Cudagraphs Mode, then resets the mode to its prior setting
- with torch_tensorrt.runtime.enable_cudagraphs():
+ with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
...

In the current implementation, use of a new input shape (for instance in dynamic shape
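For reference, a minimal end-to-end sketch of the updated context-manager signature (model choice, shapes, and variable names are illustrative; it assumes a module compiled with ir="dynamo"):

import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()

trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[inputs])

# enable_cudagraphs now takes the compiled module and yields the module
# (possibly wrapped) that should be used for inference while the mode is active
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(inputs)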
56 changes: 49 additions & 7 deletions examples/dynamo/torch_export_cudagraphs.py
@@ -11,9 +11,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
- import torchvision.models as models

import torch_tensorrt
+ import torchvision.models as models

# %%
# Compilation with `torch_tensorrt.compile` Using Default Settings
@@ -47,8 +46,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# We can enable the cudagraphs API with a context manager
- with torch_tensorrt.runtime.enable_cudagraphs():
-     out_trt = opt(inputs)
+ with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
+     out_trt = cudagraphs_module(inputs)

# Alternatively, we can set the cudagraphs mode for the session
torch_tensorrt.runtime.set_cudagraphs_mode(True)
@@ -64,6 +63,49 @@
inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
inputs_3 = torch.randn((4, 3, 224, 224)).cuda()

- with torch_tensorrt.runtime.enable_cudagraphs():
-     out_trt_2 = opt(inputs_2)
-     out_trt_3 = opt(inputs_3)
+ with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
+     out_trt_2 = cudagraphs_module(inputs_2)
+     out_trt_3 = cudagraphs_module(inputs_3)

+ # %%
+ # CUDA Graphs with modules that contain graph breaks
+ # --------------------------------------------------
+ #
+ # When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
+ # overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
+ # optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
+ # kernel launch overhead and improved execution efficiency, may be diminished.
+ # Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
+ # that can be executed efficiently, even in the presence of graph breaks.
+ # If the TensorRT module has graph breaks, the CUDA Graphs context manager returns a wrapped module. This module
+ # captures the entire execution graph, enabling efficient replay during subsequent inferences by reducing kernel
+ # launch overheads and improving performance. Note that initializing with the wrapper module involves a warm-up
+ # phase during which the module is executed several times. This warm-up ensures that memory allocations and
+ # initializations are not recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize
+ # performance.
+
+
+ class SampleModel(torch.nn.Module):
+     def forward(self, x):
+         return torch.relu((x + 2) * 0.5)
+
+
+ model = SampleModel().eval().cuda()
+ input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+ # The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module.
+ # Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner.
+ opt_with_graph_break = torch_tensorrt.compile(
+     model,
+     ir="dynamo",
+     inputs=[input],
+     min_block_size=1,
+     pass_through_build_failures=True,
+     torch_executed_ops={"torch.ops.aten.mul.Tensor"},
+ )
+
+ # %%
+ # If the module has graph breaks, whole submodules are recorded and replayed by CUDA Graphs
+ with torch_tensorrt.runtime.enable_cudagraphs(
+     opt_with_graph_break
+ ) as cudagraphs_module:
+     cudagraphs_module(input)
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_compile.py
@@ -12,6 +12,9 @@
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
+ from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+     CudaGraphsTorchTensorRTModule,
+ )
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
Save the model to disk in the specified output format.

Arguments:
- module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
+ module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
"""
+ if isinstance(module, CudaGraphsTorchTensorRTModule):
+     module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
if arg_inputs is not None and not all(
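A short sketch of the path this unwrap enables, reusing opt_with_graph_break and input from the example above (the output file name is illustrative): the module yielded by enable_cudagraphs over a graph-broken program is a CudaGraphsTorchTensorRTModule, and save() now falls back to its underlying compiled module.

with torch_tensorrt.runtime.enable_cudagraphs(opt_with_graph_break) as cudagraphs_module:
    cudagraphs_module(input)
    # save() unwraps the wrapper to module.compiled_module before exporting
    torch_tensorrt.save(cudagraphs_module, "trt.ep", inputs=[input])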