From b5930a7e050ddaebcfdf0455cf09a00d67282752 Mon Sep 17 00:00:00 2001
From: WoodieDudy
Date: Wed, 21 Aug 2024 15:55:54 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: WoodieDudy
---
 .../adapters/multi_head_attention_adapter_module.py | 12 ++++++++++--
 .../asr/parts/submodules/multi_head_attention.py    | 10 ++++------
 .../asr/mixins/adapters/test_asr_adapter_modules.py |  4 ++--
 tests/collections/asr/test_conformer_encoder.py     |  6 +++++-
 4 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
index 2bfdaacf9dad..908c23a1786e 100644
--- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
+++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
@@ -136,7 +136,9 @@ def __init__(
         adapter_strategy: MHAResidualAddAdapterStrategy = None,
         use_pytorch_sdpa: bool = True,
     ):
-        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa)
+        super().__init__(
+            n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
+        )
 
         self.pre_norm = nn.LayerNorm(n_feat)
 
@@ -230,7 +232,13 @@ def __init__(
         use_pytorch_sdpa: bool = False,
     ):
         super().__init__(
-            n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
+            n_head=n_head,
+            n_feat=n_feat,
+            dropout_rate=dropout_rate,
+            pos_bias_u=None,
+            pos_bias_v=None,
+            max_cache_len=0,
+            use_pytorch_sdpa=use_pytorch_sdpa,
         )
 
         self.pre_norm = nn.LayerNorm(n_feat)
diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index a6f7627a6f99..bdbb4dce951c 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -159,13 +159,11 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
 
             # add extra col for key and value to handle problem with nan after solfmax
             extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
-            k = torch.cat([k, extra_column], dim=-2)
+            k = torch.cat([k, extra_column], dim=-2)
 
             extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
-            v = torch.cat([v, extra_column], dim=-2)
-            out = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask, dropout_p=dropout_rate
-            )
+            v = torch.cat([v, extra_column], dim=-2)
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)
             out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
             out = self.linear_out(out)  # (batch, time1, d_model)
         else:
@@ -303,7 +301,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
             dropout_rate = self.dropout_rate if self.training else 0
             # add extra col for key and value to handle problem with nan after solfmax
             extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
-            k = torch.cat([k, extra_column], dim=-2)
+            k = torch.cat([k, extra_column], dim=-2)
 
             extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
             v = torch.cat([v, extra_column], dim=-2)
diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
index f15892707201..ad33a21262f3 100644
--- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
+++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
@@ -204,7 +204,7 @@ def test_relmha_adapter_with_torch_sdpa(self):
         _, pos_emb = relpos_enc(x)
         out = adapter(x, x, x, att_mask, pos_emb)
         out_sdpa = adapter_torch_sdpa(x, x, x, att_mask, pos_emb)
-        assert torch.allclose(out_sdpa, out, atol=1e-5)
+        assert torch.allclose(out_sdpa, out, atol=1e-5)
 
     @pytest.mark.unit
     def test_mha_adapter_with_torch_sdpa(self):
@@ -229,7 +229,7 @@ def test_mha_adapter_with_torch_sdpa(self):
         with torch.no_grad():
             out = adapter(x, x, x, att_mask)
             out_sdpa = adapter_torch_sdpa(x, x, x, att_mask)
-            assert torch.allclose(out_sdpa, out, atol=1e-5)
+            assert torch.allclose(out_sdpa, out, atol=1e-5)
 
     @pytest.mark.unit
     def test_abspos_encoding_init(self):
diff --git a/tests/collections/asr/test_conformer_encoder.py b/tests/collections/asr/test_conformer_encoder.py
index 5c1db485eddd..18cb902d1408 100644
--- a/tests/collections/asr/test_conformer_encoder.py
+++ b/tests/collections/asr/test_conformer_encoder.py
@@ -74,7 +74,11 @@ def test_stochastic_depth_model_creation(self):
         for start_layer in [-1, 0, 5]:
             with pytest.raises(ValueError, match="stochastic_depth_start_layer has to be in"):
                 ConformerEncoder(
-                    feat_in=10, n_layers=n_layers, d_model=4, feat_out=8, stochastic_depth_start_layer=start_layer,
+                    feat_in=10,
+                    n_layers=n_layers,
+                    d_model=4,
+                    feat_out=8,
+                    stochastic_depth_start_layer=start_layer,
                 )
 
     @pytest.mark.pleasefixme
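Note on the reformatted SDPA block above: the extra all-zero column appended to k and v gives fully masked query rows at least one position to attend to, which avoids NaNs after the softmax inside torch.nn.functional.scaled_dot_product_attention. The standalone sketch below illustrates that trick outside NeMo; the tensor shapes, the boolean mask convention, and the extra mask column are illustrative assumptions, not code from this patch.

import torch
import torch.nn.functional as F

# Illustrative shapes (not taken from the patch).
batch, heads, q_len, kv_len, d_head = 2, 4, 6, 6, 16
q = torch.randn(batch, heads, q_len, d_head)
k = torch.randn(batch, heads, kv_len, d_head)
v = torch.randn(batch, heads, kv_len, d_head)

# Boolean mask where True means "may attend"; fully mask one query row to
# reproduce the NaN-after-softmax case the patch guards against.
mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)
mask[0, 0, 0, :] = False

# Without padding, the fully masked row has only -inf scores before softmax.
naive = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print("NaNs without padding:", torch.isnan(naive).any().item())  # typically True

# The trick from the patch: append one all-zero key/value position and let
# every query attend to it (padding the mask with a True column is an
# assumption of this sketch, not shown in the hunks above).
extra_k = torch.zeros(k.shape[:-2] + (1, k.shape[-1]))
extra_v = torch.zeros(v.shape[:-2] + (1, v.shape[-1]))
k_pad = torch.cat([k, extra_k], dim=-2)
v_pad = torch.cat([v, extra_v], dim=-2)
mask_pad = torch.cat([mask, torch.ones(batch, 1, q_len, 1, dtype=torch.bool)], dim=-1)

out = F.scaled_dot_product_attention(q, k_pad, v_pad, attn_mask=mask_pad, dropout_p=0.0)
assert not torch.isnan(out).any()  # the fully masked row now yields zeros, not NaN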