Commit
use_pytorch_sdpa parameter forwarding
kramarenko.gs committed Aug 12, 2024
1 parent 0de7115 commit edcc9bd
Showing 3 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: true

# Convolution module's params
conv_kernel_size: 9
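
For orientation, a minimal sketch of reading the new flag back out of a config like the one above with OmegaConf (which NeMo configs are built on); the file name and the model.encoder path are assumptions for illustration, not taken from this commit.

```python
# Hypothetical snippet: the file name and exact config path are assumptions.
from omegaconf import OmegaConf

cfg = OmegaConf.load("my_conformer_config.yaml")  # placeholder file name
# The flag sits with the other encoder attention settings; since it defaults to
# True, omitting it from the YAML keeps the SDPA path enabled.
use_sdpa = cfg.model.encoder.get("use_pytorch_sdpa", True)
print(use_sdpa)
```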
8 changes: 8 additions & 0 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -147,6 +147,8 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
Defaults to 1.
global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
Defaults to False.
use_pytorch_sdpa (bool): whether to use PyTorch's scaled_dot_product_attention (SDPA) instead of the manual attention implementation.
Defaults to True.
"""

@@ -295,6 +297,7 @@ def __init__(
global_tokens: int = 0,
global_tokens_spacing: int = 1,
global_attn_separate: bool = False,
use_pytorch_sdpa: bool = True,
):
super().__init__()
d_ff = d_model * ff_expansion_factor
@@ -309,6 +312,7 @@ def __init__(
self.global_tokens = global_tokens
self.global_attn_separate = global_attn_separate
self.global_tokens_spacing = global_tokens_spacing
self.use_pytorch_sdpa = use_pytorch_sdpa

# Setting up the att_context_size
(
@@ -430,6 +434,7 @@ def __init__(
pos_bias_v=pos_bias_v,
att_context_size=self.att_context_size,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
self.layers.append(layer)

@@ -1028,6 +1033,7 @@ def change_attention_model(
max_cache_len=att_context_size[0],
pos_bias_u=None,
pos_bias_v=None,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'rel_pos_local_attn':
new_attn = RelPositionMultiHeadAttentionLongformer(
@@ -1038,13 +1044,15 @@
att_context_size=att_context_size,
pos_bias_u=None,
pos_bias_v=None,
# use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'abs_pos':
new_attn = MultiHeadAttention(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
max_cache_len=att_context_size[0],
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
else:
raise ValueError(
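
With these changes, attention modules rebuilt via change_attention_model for the 'rel_pos' and 'abs_pos' variants inherit the encoder's use_pytorch_sdpa setting, while the forwarding for the Longformer-style 'rel_pos_local_attn' branch is left commented out in this commit. A hedged usage sketch follows; the pretrained checkpoint name is only a plausible placeholder and the method arguments are inferred from the diff context.

```python
# Sketch of switching attention variants on a restored model; illustrative only.
import nemo.collections.asr as nemo_asr

# Placeholder checkpoint name; any Conformer-based NeMo ASR model would do.
asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_large")

# Rebuild the self-attention modules; with this commit the new 'rel_pos' / 'abs_pos'
# modules are constructed with the encoder's use_pytorch_sdpa value.
asr_model.encoder.change_attention_model(
    self_attention_model="rel_pos",
    att_context_size=[-1, -1],
)
```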
4 changes: 4 additions & 0 deletions nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -77,9 +77,11 @@ def __init__(
pos_bias_v=None,
att_context_size=[-1, -1],
use_bias=True,
use_pytorch_sdpa=True,
):
super(ConformerLayer, self).__init__()

self.use_pytorch_sdpa = use_pytorch_sdpa
self.self_attention_model = self_attention_model
self.n_heads = n_heads
self.fc_factor = 0.5
@@ -111,6 +113,7 @@ def __init__(
pos_bias_v=pos_bias_v,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'rel_pos_local_attn':
self.self_attn = RelPositionMultiHeadAttentionLongformer(
@@ -133,6 +136,7 @@
dropout_rate=dropout_att,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
else:
raise ValueError(
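
The attention classes that receive the forwarded flag are expected to branch on it at call time; the following is a generic sketch of that pattern under that assumption, not the actual NeMo multi_head_attention.py code.

```python
# Generic sketch of how a multi-head attention module might honor the flag; not NeMo code.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchSelfAttention(nn.Module):
    def __init__(self, n_feat: int, n_head: int, use_pytorch_sdpa: bool = True):
        super().__init__()
        self.d_k = n_feat // n_head
        self.use_pytorch_sdpa = use_pytorch_sdpa

    def forward(self, q, k, v):
        # q, k, v: (batch, heads, time, d_k); projections and masking omitted for brevity
        if self.use_pytorch_sdpa:
            # Fused path: lets PyTorch pick a flash / memory-efficient / math backend
            return F.scaled_dot_product_attention(q, k, v)
        # Manual path: explicit score matrix plus softmax
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return torch.matmul(torch.softmax(scores, dim=-1), v)


attn = SketchSelfAttention(n_feat=512, n_head=8, use_pytorch_sdpa=True)
out = attn(torch.randn(2, 8, 50, 64), torch.randn(2, 8, 50, 64), torch.randn(2, 8, 50, 64))
```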
