use torch scaled_dot_product_attention
kramarenko.gs committed May 26, 2024
1 parent 1fa961b commit b4a1c74
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -219,6 +219,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
 
         n_batch_pos = pos_emb.size(0)
+        n_batch = value.size(0)
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose(1, 2)  # (batch, head, time1, d_k)
@@ -227,22 +228,24 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         # (batch, head, time1, d_k)
         q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
 
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
 
         # compute matrix b and matrix d
         # (batch, head, time1, time2)
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         matrix_bd = self.rel_shift(matrix_bd)
-        # drops extra elements in the matrix_bd to match the matrix_ac's size
-        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
+        scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
+        matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor
 
-        scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            attn_mask = matrix_bd.masked_fill(mask, float("-inf"))
+        else:
+            attn_mask = matrix_bd
 
-        out = self.forward_attention(v, scores, mask)
+        out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=attn_mask)
+        out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
+        out = torch.nan_to_num(out, nan=0.0)
+        out = self.linear_out(out)  # (batch, time1, d_model)
 
         if cache is None:
             return out
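
Note on the rewrite (not part of the diff): torch.nn.functional.scaled_dot_product_attention computes softmax(q @ k^T / sqrt(d_k) + attn_mask) @ v, so the relative-position term matrix_bd is pre-scaled by scale_factor = 1 / sqrt(d_k) before being passed as attn_mask; the sum inside the softmax then matches the original scores = (matrix_ac + matrix_bd) / self.s_d_k. The torch.nan_to_num call guards fully masked query rows, where softmax over a row of all -inf yields NaN. Below is a minimal standalone sketch of this equivalence, assuming PyTorch >= 2.0, with illustrative shapes and tensor names that only mirror the diff:

import math
import torch

# Illustrative shapes; names mirror the diff but this is not NeMo code.
batch, heads, time1, time2, d_k = 2, 4, 8, 8, 16
q = torch.randn(batch, heads, time1, d_k)  # plays the role of q_with_bias_u
k = torch.randn(batch, heads, time2, d_k)
v = torch.randn(batch, heads, time2, d_k)
matrix_bd = torch.randn(batch, heads, time1, time2)  # relative-position term

# Original path: explicit scores, softmax, and weighted sum over values.
matrix_ac = torch.matmul(q, k.transpose(-2, -1))
scores = (matrix_ac + matrix_bd) / math.sqrt(d_k)
ref = torch.matmul(torch.softmax(scores, dim=-1), v)

# SDPA path: SDPA divides q @ k^T by sqrt(d_k) internally, so the additive
# bias must be pre-scaled by the same factor before going in as attn_mask.
scale_factor = 1 / math.sqrt(d_k)
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=matrix_bd * scale_factor
)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)  # both paths agree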