Re-enable CUDA graphs in training modes.

"global" capture mode was sporadically crashing because threads spawned by the data loader pin host memory when num_workers > 0.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
galv committed May 29, 2024
1 parent 525604f commit cec4f53
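
Background for the fix, as a minimal sketch rather than NeMo code (it assumes a PyTorch build that exposes the capture_error_mode argument of torch.cuda.graph): under the default capture_error_mode="global", an in-progress capture is invalidated by potentially unsafe CUDA calls, such as host-memory pinning, made by any thread in the process; "thread_local" limits that check to the capturing thread, which is what the diffs below switch to.

import torch

assert torch.cuda.is_available()
device = torch.device("cuda")

static_x = torch.zeros(8, device=device)
graph = torch.cuda.CUDAGraph()

# Warm-up on a side stream before capture, per the usual CUDA graphs recipe.
s = torch.cuda.Stream(device)
s.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(s):
    static_x += 1
torch.cuda.current_stream(device).wait_stream(s)
static_x.zero_()

# Capture on a dedicated stream; "thread_local" ignores unsafe CUDA calls
# (e.g., host-memory pinning) made by other threads during the capture.
stream_for_graph = torch.cuda.Stream(device)
with torch.cuda.graph(graph, stream=stream_for_graph, capture_error_mode="thread_local"):
    static_x += 1  # recorded into the graph, not executed here

graph.replay()  # runs the captured "+= 1"
print(static_x)  # tensor of ones
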
Showing 6 changed files with 14 additions and 19 deletions.
examples/asr/transcribe_speech.py (1 addition, 3 deletions)
@@ -163,9 +163,7 @@ class TranscriptionConfig:

     # Decoding strategy for RNNT models
     # enable CUDA graphs for transcription
-    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
-        fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
-    )
+    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

     # Decoding strategy for AED models
     multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
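
With use_cuda_graph_decoder now defaulting to True (see the rnnt_greedy_decoding.py hunks below), the explicit greedy override deleted above becomes redundant. A hedged sketch of the two configurations, using the config classes from this commit (import paths follow the changed files):

from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig

# Equivalent to the new default above: CUDA graph decoding is on.
cfg = RNNTDecodingConfig(fused_batch_size=-1)

# Opting out (e.g., for debugging) is now the explicit case.
cfg_no_graphs = RNNTDecodingConfig(
    fused_batch_size=-1,
    greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=False),
)
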
examples/asr/transcribe_speech_parallel.py (2 additions, 4 deletions)
@@ -101,10 +101,8 @@ class ParallelTranscriptionConfig:
     use_cer: bool = False

     # decoding strategy for RNNT models
-    # enable CUDA graphs for transcription
-    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
-        fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
-    )
+    # Double check whether fused_batch_size=-1 is right
+    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

     # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
     decoder_type: Optional[str] = None
@@ -172,7 +172,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
         # Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
         stream_for_graph = torch.cuda.Stream(self.device)
         with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(
-            self.graph, stream=stream_for_graph
+            self.graph, stream=stream_for_graph, capture_error_mode="thread_local"
         ):
             # This is failing...
             self.f = torch.zeros(
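
To illustrate the failure conditions named in the commit message (a hypothetical sketch, not code from this commit): a DataLoader with num_workers > 0 and pin_memory=True runs a pin-memory thread inside the main process, and its host-memory pinning can race with an active capture. Under "global" mode that race sporadically aborts the capture; "thread_local" makes the capture ignore other threads.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 80))
# num_workers > 0 plus pin_memory=True starts a pinning thread in this process.
loader = DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)
batches = iter(loader)  # prefetching (and pinning) begins in the background

x = torch.zeros(8, 80, device="cuda")
graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream()
with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
    # Under "global" mode, concurrent host-memory pinning by the pin-memory
    # thread could invalidate this capture; "thread_local" does not observe it.
    y = x * 2.0
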
nemo/collections/asr/parts/submodules/rnnt_decoding.py (2 additions, 2 deletions)
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
                 preserve_frame_confidence=self.preserve_frame_confidence,
                 confidence_method_cfg=self.confidence_method_cfg,
                 loop_labels=self.cfg.greedy.get('loop_labels', True),
-                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
+                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
             )
         else:
             self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
                 preserve_frame_confidence=self.preserve_frame_confidence,
                 include_duration_confidence=self.tdt_include_duration_confidence,
                 confidence_method_cfg=self.confidence_method_cfg,
-                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
+                use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
             )

         else:
nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py (3 additions, 3 deletions)
@@ -592,7 +592,7 @@ def __init__(
         preserve_frame_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
         loop_labels: bool = True,
-        use_cuda_graph_decoder: bool = False,
+        use_cuda_graph_decoder: bool = True,
     ):
         super().__init__(
             decoder_model=decoder_model,
@@ -2360,7 +2360,7 @@ class GreedyBatchedRNNTInferConfig:
     tdt_include_duration_confidence: bool = False
     confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
     loop_labels: bool = True
-    use_cuda_graph_decoder: bool = False
+    use_cuda_graph_decoder: bool = True

     def __post_init__(self):
         # OmegaConf.structured ensures that post_init check is always executed
@@ -2712,7 +2712,7 @@ def __init__(
         preserve_frame_confidence: bool = False,
         include_duration_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
-        use_cuda_graph_decoder: bool = False,
+        use_cuda_graph_decoder: bool = True,
     ):
         super().__init__(
             decoder_model=decoder_model,
@@ -630,42 +630,41 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph),
+            torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_inner_loop_get_decoder_output()
             self._before_inner_loop_get_joint_output()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph),
+            torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._inner_loop_code()

         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph),
+            torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._after_inner_loop()

     def _full_graph_compile(self):
         """Compile full graph for decoding"""
         # Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
         stream_for_graph = torch.cuda.Stream(self.state.device)
         stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device))
         self.full_graph = torch.cuda.CUDAGraph()
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.full_graph, stream=stream_for_graph),
+            torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()
