Re-enable cuda graphs in training modes. (NVIDIA#9338) (NVIDIA#9343)
* Re-enable cuda graphs in training modes.

"global" capture mode was sporadically crashing because of pinning
host memory in other threads spawned by the data loader when
num_workers > 0.

Add relevant changs to TDT cuda graphs decoding as well.

I didn't test the TDT change because I'm not sure how. But it seems low risk.

* Apply isort and black reformatting

---------

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: galv <galv@users.noreply.github.com>
Co-authored-by: Daniel Galvez <galv@users.noreply.github.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
3 people authored and BoxiangW committed Jun 5, 2024
1 parent 5f2573c commit 0d20bf9
Showing 7 changed files with 39 additions and 26 deletions.
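
For context on the fix described in the commit message, here is a minimal repro-style sketch (not code from this commit; names and shapes are illustrative). A DataLoader with num_workers > 0 pins host memory from background threads; under the default capture_error_mode="global", CUDA activity from those threads during capture can abort it, while "thread_local" only watches the capturing thread:

```python
# Illustrative sketch only; assumes a CUDA-capable PyTorch build.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(
    TensorDataset(torch.randn(256, 80)),
    batch_size=16,
    num_workers=2,    # background workers plus a pin-memory thread
    pin_memory=True,  # host-memory pinning happens off the main thread
)
batches = iter(loader)  # side threads start touching the CUDA API here

graph = torch.cuda.CUDAGraph()
side_stream = torch.cuda.Stream()  # capture is not allowed on the default stream
static_x = torch.zeros(16, 80, device="cuda")
with torch.cuda.stream(side_stream), torch.cuda.graph(
    graph, stream=side_stream, capture_error_mode="thread_local"
):
    static_y = static_x * 2.0  # recorded into the graph; foreign-thread CUDA calls no longer abort capture
```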
4 changes: 1 addition & 3 deletions examples/asr/transcribe_speech.py
@@ -164,9 +164,7 @@ class TranscriptionConfig:

     # Decoding strategy for RNNT models
     # enable CUDA graphs for transcription
-    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
-        fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
-    )
+    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

     # Decoding strategy for AED models
     multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
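
Because use_cuda_graph_decoder now defaults to True (see the rnnt_greedy_decoding.py hunk below), the explicit greedy override is redundant. A hedged sketch of the equivalence, assuming the import paths this script already uses:

```python
# Sketch only; the old and new config forms should be equivalent.
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig

explicit = RNNTDecodingConfig(
    fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
)
implicit = RNNTDecodingConfig(fused_batch_size=-1)  # the new, shorter form
assert implicit.greedy.use_cuda_graph_decoder == explicit.greedy.use_cuda_graph_decoder
```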
6 changes: 2 additions & 4 deletions examples/asr/transcribe_speech_parallel.py
@@ -101,10 +101,8 @@ class ParallelTranscriptionConfig:

     use_cer: bool = False

     # decoding strategy for RNNT models
-    # enable CUDA graphs for transcription
-    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
-        fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
-    )
+    # Double check whether fused_batch_size=-1 is right
+    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

     # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
     decoder_type: Optional[str] = None
nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py

@@ -37,7 +37,7 @@

 def create_outer_for_loop_kernel():
     """
-    Creates a kernel that evaluates whether or not to enter the for loop body. 
+    Creates a kernel that evaluates whether or not to enter the for loop body.
     Effectively substitutes for `for time_idx in range(trip_count)`
     such that that for loop can run on a GPU.
     """
@@ -171,8 +171,10 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len

         # Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
         stream_for_graph = torch.cuda.Stream(self.device)
-        with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(
-            self.graph, stream=stream_for_graph
+        with (
+            torch.cuda.stream(stream_for_graph),
+            torch.inference_mode(),
+            torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             # This is failing...
             self.f = torch.zeros(
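The comment in this hunk states the key constraint: graph capture must run on a dedicated side stream, never the per-thread default stream. A self-contained sketch of the capture-then-replay lifecycle, under the same assumptions as above:

```python
# Minimal capture/replay sketch; illustrative, not NeMo code.
import torch

stream_for_graph = torch.cuda.Stream()     # fresh stream: default-stream capture is disallowed
graph = torch.cuda.CUDAGraph()
static_in = torch.zeros(8, device="cuda")  # I/O buffers must outlive the graph
with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(
    graph, stream=stream_for_graph, capture_error_mode="thread_local"
):
    static_out = static_in * 2.0           # ops are recorded, not executed

static_in.fill_(3.0)  # update inputs in place...
graph.replay()        # ...then replay the recorded kernels: static_out becomes 6.0
```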
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):

                 preserve_frame_confidence=self.preserve_frame_confidence,
                 confidence_method_cfg=self.confidence_method_cfg,
                 loop_labels=self.cfg.greedy.get('loop_labels', True),
-                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
+                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
             )
         else:
             self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(

@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):

                 preserve_frame_confidence=self.preserve_frame_confidence,
                 include_duration_confidence=self.tdt_include_duration_confidence,
                 confidence_method_cfg=self.confidence_method_cfg,
-                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
+                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
             )

         else:
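
With the fallback flipped to True here, CUDA graph decoding is enabled unless the user's decoding config opts out. A minimal sketch of the opt-out, mirroring the cfg.greedy.get(...) lookup above:

```python
# Sketch: a decoding config that explicitly disables the CUDA graph decoder.
from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create({"greedy": {"use_cuda_graph_decoder": False}})
# Mirrors the lookup in this hunk:
assert decoding_cfg.greedy.get("use_cuda_graph_decoder", True) is False
```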
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -592,7 +592,7 @@ def __init__(

         preserve_frame_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
         loop_labels: bool = True,
-        use_cuda_graph_decoder: bool = False,
+        use_cuda_graph_decoder: bool = True,
     ):
         super().__init__(
             decoder_model=decoder_model,

@@ -2360,7 +2360,7 @@ class GreedyBatchedRNNTInferConfig:

     tdt_include_duration_confidence: bool = False
     confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
     loop_labels: bool = True
-    use_cuda_graph_decoder: bool = False
+    use_cuda_graph_decoder: bool = True

     def __post_init__(self):
         # OmegaConf.structured ensures that post_init check is always executed

@@ -2712,7 +2712,7 @@ def __init__(

         preserve_frame_confidence: bool = False,
         include_duration_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
-        use_cuda_graph_decoder: bool = False,
+        use_cuda_graph_decoder: bool = True,
     ):
         super().__init__(
             decoder_model=decoder_model,
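
The dataclass default above is the single switch that the cfg.greedy.get(...) lookups fall back to. A quick sketch of how the new default materializes through OmegaConf (import path as shown in this file's diff header):

```python
from omegaconf import OmegaConf
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig

# Instantiating the dataclass runs __post_init__ validation; OmegaConf.structured
# then freezes it into a typed config carrying the new default.
cfg = OmegaConf.structured(GreedyBatchedRNNTInferConfig())
assert cfg.use_cuda_graph_decoder is True
```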
19 changes: 13 additions & 6 deletions nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py
@@ -630,42 +630,49 @@ def _partial_graphs_compile(self):

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._before_outer_loop()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._before_inner_loop_get_decoder_output()
             self._before_inner_loop_get_joint_output()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._inner_loop_code()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._after_inner_loop()

     def _full_graph_compile(self):
         """Compile full graph for decoding"""
         # Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
         stream_for_graph = torch.cuda.Stream(self.state.device)
         stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device))
         self.full_graph = torch.cuda.CUDAGraph()
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.full_graph, stream=stream_for_graph),
+            torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()
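
A generic sketch of the pattern _partial_graphs_compile implements: several small graphs captured on one side stream (each with "thread_local" error mode), then replayed from a host-side loop that supplies the control flow the graphs themselves cannot contain. All names below are illustrative, not NeMo's:

```python
import torch

s = torch.cuda.Stream()
before_loop, loop_body = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
state = torch.zeros(4, device="cuda")

with torch.cuda.stream(s), torch.inference_mode(), torch.cuda.graph(
    before_loop, stream=s, capture_error_mode="thread_local"
):
    state.zero_()
with torch.cuda.stream(s), torch.inference_mode(), torch.cuda.graph(
    loop_body, stream=s, capture_error_mode="thread_local"
):
    state.add_(1.0)

before_loop.replay()
for _ in range(10):  # host loop stands in for the decoding while-loop
    loop_body.replay()
# state now holds 10.0 in every element
```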
18 changes: 13 additions & 5 deletions nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
@@ -691,29 +691,37 @@ def _partial_graphs_compile(self):

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._before_outer_loop()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._before_inner_loop_get_decoder_output()
             self._before_inner_loop_get_joint_output()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._inner_loop_code()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(
+                self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
+            ),
         ):
             self._after_inner_loop()

@@ -726,7 +734,7 @@ def _full_graph_compile(self):

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.full_graph, stream=stream_for_graph),
+            torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()
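
Since the commit message notes the TDT path went untested, here is a hedged smoke-test sketch (the model name, audio path, and exact config access are assumptions, not part of this commit):

```python
import nemo.collections.asr as nemo_asr
from omegaconf import open_dict

# Any TDT transducer checkpoint should exercise GreedyBatchedTDTInfer's CUDA graph path.
model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-1.1b")
decoding_cfg = model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.greedy.use_cuda_graph_decoder = True  # now also the default
model.change_decoding_strategy(decoding_cfg)
print(model.transcribe(["sample.wav"]))  # hypothetical audio file
```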
