diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
index d244de33c4f4..4f5f7364171e 100644
--- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
+++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
@@ -143,7 +143,7 @@ def __init__(
             dropout_rate=dropout_rate,
             max_cache_len=0,
             use_pytorch_sdpa=use_pytorch_sdpa,
-            use_pytorch_sdpa_backends=use_pytorch_sdpa_backends
+            use_pytorch_sdpa_backends=use_pytorch_sdpa_backends,
         )

         self.pre_norm = nn.LayerNorm(n_feat)
diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 1d9e2c6a67e0..3cde743da833 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -167,7 +167,9 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
             dropout_rate = self.dropout_rate if self.training else 0
             with sdpa_kernel(self.use_pytorch_sdpa_backends):
-                out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)
+                out = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, dropout_p=dropout_rate
+                )

             # this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
             if mask is not None:
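For context (not part of the diff): a minimal, standalone sketch of the pattern the second hunk reformats, namely constraining PyTorch's scaled dot-product attention to an explicit list of backends via the `sdpa_kernel` context manager. The tensor shapes, mask, and backend list below are illustrative assumptions, not values taken from NeMo.

```python
# Illustrative sketch only; assumes PyTorch >= 2.3 (torch.nn.attention.sdpa_kernel).
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Dummy inputs: (batch, heads, seq_len, head_dim); shapes are arbitrary.
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)
# Boolean attention mask, broadcast over heads; True = attend.
mask = torch.ones(2, 1, 16, 16, dtype=torch.bool)

# Restrict SDPA to a chosen set of kernels; MATH is the always-available fallback.
backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
with sdpa_kernel(backends):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0
    )

print(out.shape)  # torch.Size([2, 4, 16, 32])
```

In the patched NeMo code the backend list is carried in `self.use_pytorch_sdpa_backends`, so callers can pick the kernels at construction time rather than hard-coding them at the attention call site.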