diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
index 680d96e1afaf..c6fa75ddce1c 100644
--- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
@@ -145,6 +145,7 @@ model:
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000
+    use_pytorch_sdpa: true

    # Convolution module's params
    conv_kernel_size: 9
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index 245404a7601c..ea887918b143 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -147,6 +147,8 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
            Defaults to 1.
        global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
            Defaults to False.
+        use_pytorch_sdpa (bool): use PyTorch's scaled_dot_product_attention (SDPA) instead of the manual attention implementation.
+            Defaults to True.
    """
@@ -295,6 +297,7 @@ def __init__(
        global_tokens: int = 0,
        global_tokens_spacing: int = 1,
        global_attn_separate: bool = False,
+        use_pytorch_sdpa: bool = True,
    ):
        super().__init__()
        d_ff = d_model * ff_expansion_factor
@@ -309,6 +312,7 @@ def __init__(
        self.global_tokens = global_tokens
        self.global_attn_separate = global_attn_separate
        self.global_tokens_spacing = global_tokens_spacing
+        self.use_pytorch_sdpa = use_pytorch_sdpa

        # Setting up the att_context_size
        (
@@ -430,6 +434,7 @@ def __init__(
                pos_bias_v=pos_bias_v,
                att_context_size=self.att_context_size,
                use_bias=use_bias,
+                use_pytorch_sdpa=self.use_pytorch_sdpa,
            )
            self.layers.append(layer)
@@ -1028,6 +1033,7 @@ def change_attention_model(
                    max_cache_len=att_context_size[0],
                    pos_bias_u=None,
                    pos_bias_v=None,
+                    use_pytorch_sdpa=self.use_pytorch_sdpa,
                )
            elif self_attention_model == 'rel_pos_local_attn':
                new_attn = RelPositionMultiHeadAttentionLongformer(
@@ -1038,6 +1044,7 @@ def change_attention_model(
                    att_context_size=att_context_size,
                    pos_bias_u=None,
                    pos_bias_v=None,
+                    # use_pytorch_sdpa=self.use_pytorch_sdpa,
                )
            elif self_attention_model == 'abs_pos':
                new_attn = MultiHeadAttention(
@@ -1045,6 +1052,7 @@ def change_attention_model(
                    n_feat=self._cfg.d_model,
                    dropout_rate=self._cfg.dropout_att,
                    max_cache_len=att_context_size[0],
+                    use_pytorch_sdpa=self.use_pytorch_sdpa,
                )
            else:
                raise ValueError(
diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py
index c2d897d63225..09bd369ffbbc 100644
--- a/nemo/collections/asr/parts/submodules/conformer_modules.py
+++ b/nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -77,9 +77,11 @@ def __init__(
        pos_bias_v=None,
        att_context_size=[-1, -1],
        use_bias=True,
+        use_pytorch_sdpa=True,
    ):
        super(ConformerLayer, self).__init__()

+        self.use_pytorch_sdpa = use_pytorch_sdpa
        self.self_attention_model = self_attention_model
        self.n_heads = n_heads
        self.fc_factor = 0.5
@@ -111,6 +113,7 @@ def __init__(
                pos_bias_v=pos_bias_v,
                max_cache_len=MHA_max_cache_len,
                use_bias=use_bias,
+                use_pytorch_sdpa=self.use_pytorch_sdpa,
            )
        elif self_attention_model == 'rel_pos_local_attn':
            self.self_attn = RelPositionMultiHeadAttentionLongformer(
@@ -133,6 +136,7 @@ def __init__(
                dropout_rate=dropout_att,
                max_cache_len=MHA_max_cache_len,
                use_bias=use_bias,
+                use_pytorch_sdpa=self.use_pytorch_sdpa,
            )
        else:
            raise ValueError(
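For context on what the new flag switches between, here is a minimal, self-contained sketch. It is not the PR's actual NeMo code; the `attend` helper, its tensor shapes, and its boolean mask convention are illustrative assumptions. With `use_pytorch_sdpa=True` the attention call is routed through `torch.nn.functional.scaled_dot_product_attention`; otherwise an explicit `softmax(QK^T / sqrt(d)) V` computation is used.

```python
# Sketch only: illustrates the use_pytorch_sdpa switch, not NeMo's internal modules.
import math

import torch
import torch.nn.functional as F


def attend(q, k, v, mask=None, use_pytorch_sdpa=True):
    """q, k, v: (batch, heads, time, d_head); mask: boolean, True = attend to that position."""
    if use_pytorch_sdpa:
        # Fused path: PyTorch picks an efficient SDPA kernel when one is available.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Manual reference path, equivalent up to floating-point differences.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


if __name__ == "__main__":
    q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
    out_sdpa = attend(q, k, v, use_pytorch_sdpa=True)
    out_manual = attend(q, k, v, use_pytorch_sdpa=False)
    print(torch.allclose(out_sdpa, out_manual, atol=1e-5))
```

Both branches compute the same attention output; the SDPA path simply lets PyTorch dispatch to a fused kernel when the inputs and hardware allow it, which is why the flag defaults to True in the config and module signatures above.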