Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
  • Loading branch information
WoodieDudy committed Aug 26, 2024
1 parent 18559da commit 58252d8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):

dropout_rate = self.dropout_rate if self.training else 0
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)

# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
all_masked_rows = torch.all(~mask, dim=-1)
Expand Down Expand Up @@ -301,7 +301,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
all_masked_rows.unsqueeze_(-1)
all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
out = out.masked_fill(all_masked_rows, 0.0)

out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
Expand Down

0 comments on commit 58252d8

Please sign in to comment.