replace -inf with -10k; manual dropout; fix masking
Signed-off-by: WoodieDudy <goshagks@yandex.ru>
kramarenko.gs committed Aug 12, 2024
1 parent 2c87bef commit a6d8c52
Showing 1 changed file with 11 additions and 9 deletions.
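
A note on the headline change: filling masked positions with float("-inf") makes softmax return NaN for any query row whose key positions are all masked, since every score in that row becomes -inf. A large negative constant such as -10000.0 keeps the softmax finite and matches the fill value already used by the non-SDPA attention path in this module. A minimal standalone sketch of the difference (illustrative only, not NeMo code):

```python
import torch

scores = torch.zeros(1, 4)                         # one query row over four key positions
fully_masked = torch.ones(1, 4, dtype=torch.bool)  # every key position is padding

inf_fill = scores.masked_fill(fully_masked, float("-inf"))
big_neg_fill = scores.masked_fill(fully_masked, -10000.0)

print(torch.softmax(inf_fill, dim=-1))      # tensor([[nan, nan, nan, nan]])
print(torch.softmax(big_neg_fill, dim=-1))  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])
```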
nemo/collections/asr/parts/submodules/multi_head_attention.py (11 additions, 9 deletions)
@@ -58,10 +58,10 @@ class MultiHeadAttention(nn.Module):
         use_bias (bool): whether to remove bias in linear and conv layers
     """

-    def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_sdpa=False, use_bias=True):
+    def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_bias=True, use_pytorch_sdpa=True):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadAttention, self).__init__()
-        self.use_sdpa = use_sdpa
+        self.use_pytorch_sdpa = use_pytorch_sdpa
         self.cache_drop_size = None
         self.use_bias = use_bias
         self.dropout_rate = dropout_rate
@@ -144,15 +144,16 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)

-            if self.use_sdpa:
+            if self.use_pytorch_sdpa:
                 scale = 1 / self.s_d_k
                 n_batch = value.size(0)

                 if mask is not None:
                     mask = mask.unsqueeze(1)

+                dropout_rate = self.dropout_rate if self.training else 0
                 out = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale
+                    q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
                 )
                 out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                 out = self.linear_out(out)  # (batch, time1, d_model)
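
Why the manual dropout gate above: torch.nn.functional.scaled_dot_product_attention applies dropout_p unconditionally, since it is a stateless function with no notion of train/eval mode, unlike an nn.Dropout module. Deriving dropout_rate from self.training keeps attention dropout active during training and disabled at inference. A minimal sketch of the pattern (the SdpaBlock module here is hypothetical, not NeMo's class):

```python
import torch
import torch.nn.functional as F


class SdpaBlock(torch.nn.Module):
    """Toy module demonstrating the train/eval dropout gate."""

    def __init__(self, dropout_rate: float):
        super().__init__()
        self.dropout_rate = dropout_rate

    def forward(self, q, k, v, mask=None):
        # SDPA has no train/eval awareness, so gate dropout on self.training.
        dropout_p = self.dropout_rate if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)


q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, time, head_dim)
block = SdpaBlock(dropout_rate=0.1)
block.eval()
assert torch.allclose(block(q, k, v), block(q, k, v))  # deterministic at inference
```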
@@ -184,7 +185,7 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
     """

     def __init__(
-        self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_bias=True, use_sdpa=False
+        self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_bias=True, use_pytorch_sdpa=True
     ):
         """Construct an RelPositionMultiHeadedAttention object."""
         super().__init__(
@@ -193,7 +194,7 @@ def __init__(
             dropout_rate=dropout_rate,
             max_cache_len=max_cache_len,
             use_bias=use_bias,
-            use_sdpa=use_sdpa,
+            use_pytorch_sdpa=use_pytorch_sdpa,
         )
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
@@ -267,16 +268,17 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
             matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
             matrix_bd = self.rel_shift(matrix_bd)

-            if self.use_sdpa:
+            if self.use_pytorch_sdpa:
                 scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
                 matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor

                 if mask is not None:
                     mask = mask.unsqueeze(1)
-                    matrix_bd.masked_fill_(mask.logical_not(), float("-inf"))
+                    matrix_bd.masked_fill_(mask, -10000.0)

+                dropout_rate = self.dropout_rate if self.training else 0
                 out = torch.nn.functional.scaled_dot_product_attention(
-                    q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor
+                    q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
                 )
                 out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                 out = self.linear_out(out)  # (batch, time1, d_model)
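
One more note on the relative-position hunk: matrix_bd is handed to SDPA as a float attn_mask, which SDPA simply adds to the scaled q @ k^T scores. That is why the hunk pre-multiplies matrix_bd by scale_factor (SDPA only scales the q @ k^T term itself) and folds the padding mask into it as a -10000.0 fill rather than passing a separate boolean mask. A rough self-contained sketch of the same pattern, with illustrative shapes and names rather than NeMo's:

```python
import math

import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 6, 16
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)

pos_bias = torch.randn(batch, heads, time, time)               # stand-in for matrix_bd
pad_mask = torch.zeros(batch, 1, time, time, dtype=torch.bool)
pad_mask[..., -2:] = True                                      # pretend the last two frames are padding

scale = 1 / math.sqrt(d_k)
# Pre-scale the positional bias (SDPA only scales q @ k^T) and fold the
# padding mask into it as a large negative additive term.
attn_bias = (pos_bias * scale).masked_fill(pad_mask, -10000.0)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=0.0, scale=scale)
print(out.shape)  # torch.Size([2, 4, 6, 16])
```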