diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 53d6c97aa4049..cb4c981c2f4a3 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -154,7 +154,7 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
                 dropout_rate = self.dropout_rate if self.training else 0
                 out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)
-
+
                 # this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
                 if mask is not None:
                     all_masked_rows = torch.all(~mask, dim=-1)
@@ -301,7 +301,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
                     all_masked_rows.unsqueeze_(-1)
                     all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
                     out = out.masked_fill(all_masked_rows, 0.0)
-
+
                 out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                 out = self.linear_out(out)  # (batch, time1, d_model)
             else:
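
Note (not part of the patch): a minimal, self-contained sketch of the behavior the guarded block works around. The tensor shapes and the toy mask below are illustrative assumptions; only the masked_fill pattern mirrors the diff. On PyTorch builds that do not yet include https://github.com/pytorch/pytorch/pull/131863, a query row whose attn_mask entries are all False makes scaled_dot_product_attention return NaNs for that row, so the workaround zeroes those rows out explicitly.

import torch

# Illustrative shapes (assumptions, not taken from the patch).
batch, heads, time, d_k = 2, 4, 5, 8
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)

# Boolean mask: True = attend, False = ignore. Make one query row fully masked.
mask = torch.ones(batch, 1, time, time, dtype=torch.bool)
mask[0, 0, 0, :] = False

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(torch.isnan(out[0, :, 0]).any())  # True on builds without the upstream fix

# Workaround mirroring the diff: zero out rows whose mask is entirely False.
all_masked_rows = torch.all(~mask, dim=-1)                       # (batch, 1, time)
all_masked_rows = all_masked_rows.unsqueeze(-1)                  # (batch, 1, time, 1)
all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
out = out.masked_fill(all_masked_rows, 0.0)
print(torch.isnan(out).any())  # tensor(False)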