Commit
add use_sdpa arg
Signed-off-by: WoodieDudy <goshagks@yandex.ru>
WoodieDudy authored and kramarenko.gs committed Jul 18, 2024
1 parent fafcd0a commit 95ea37c
Showing 4 changed files with 65 additions and 424 deletions.
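For orientation, here is a minimal usage sketch of the new flag. It is not part of the commit; it only assumes the constructor and forward signatures visible in the diff below (module path taken from the changed file), and the shapes, mask layout, and hyperparameters are illustrative.

```python
import torch
from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention

# Hypothetical smoke test: enable the PyTorch SDPA path via the new flag.
mha = MultiHeadAttention(n_head=4, n_feat=256, dropout_rate=0.0, use_sdpa=True)
mha.eval()

x = torch.randn(2, 50, 256)                      # (batch, time, n_feat)
mask = torch.zeros(2, 50, 50, dtype=torch.bool)  # (batch, time1, time2); True = masked out

with torch.no_grad():
    out = mha(query=x, key=x, value=x, mask=mask)

print(out.shape)  # expected: torch.Size([2, 50, 256])
```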
111 changes: 65 additions & 46 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -57,11 +57,12 @@ class MultiHeadAttention(nn.Module):
dropout_rate (float): dropout rate
"""

def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_sdpa=False):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadAttention, self).__init__()
self.dropout_rate = dropout_rate
self.use_sdpa = use_sdpa
self.cache_drop_size = None
self.dropout_rate = dropout_rate
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
@@ -96,28 +97,28 @@ def forward_qkv(self, query, key, value):

return q, k, v

# def forward_attention(self, value, scores, mask):
# """Compute attention context vector.
# Args:
# value (torch.Tensor): (batch, time2, size)
# scores(torch.Tensor): (batch, time1, time2)
# mask(torch.Tensor): (batch, time1, time2)
# returns:
# value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores
# """
# n_batch = value.size(0)
# if mask is not None:
# mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
# scores = scores.masked_fill(mask, -10000.0)
# attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
# else:
# attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

# p_attn = self.dropout(attn)
# x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
# x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)

# return self.linear_out(x) # (batch, time1, d_model)
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): (batch, time2, size)
scores(torch.Tensor): (batch, time1, time2)
mask(torch.Tensor): (batch, time1, time2)
returns:
value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -10000.0)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

def forward(self, query, key, value, mask, pos_emb=None, cache=None):
"""Compute 'Scaled Dot Product Attention'.
@@ -133,21 +134,28 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
n_batch = value.size(0)

if torch.is_autocast_enabled():
query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)


# temporary until we solve this more gracefully
with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
scale = 1 / self.s_d_k

if mask is not None:
mask = mask.unsqueeze(1).logical_not()

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)

if self.use_sdpa:
scale = 1 / self.s_d_k
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1).logical_not()  # boolean attn_mask for SDPA: True = attend, so invert the padding mask

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
out = self.forward_attention(v, scores, mask)

if cache is None:
return out
else:
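A standalone sketch, not NeMo code, of the equivalence the branch above depends on: PyTorch's boolean attn_mask treats True as "may attend", while the NeMo-style padding mask consumed by forward_attention treats True as "masked out", so the mask has to be inverted before it is handed to scaled_dot_product_attention. All shapes and values below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, d = 1, 2, 5, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

# NeMo-style padding mask: True marks key positions that must NOT be attended to.
pad_mask = torch.zeros(b, 1, t, t, dtype=torch.bool)
pad_mask[..., -2:] = True  # pretend the last two frames are padding

# Manual path, mirroring forward_attention (large negative fill, then zeroed weights).
scores = torch.matmul(q, k.transpose(-2, -1)) / d ** 0.5
attn = torch.softmax(scores.masked_fill(pad_mask, -10000.0), dim=-1).masked_fill(pad_mask, 0.0)
ref = torch.matmul(attn, v)

# SDPA path: boolean attn_mask uses the opposite convention (True = attend).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=pad_mask.logical_not(), dropout_p=0.0)

print(torch.allclose(out, ref, atol=1e-5))  # expected: True
```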
@@ -170,10 +178,9 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
dropout_rate (float): dropout rate
"""

def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0):
def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_sdpa=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len)
self.dropout_rate = dropout_rate
super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len, use_sdpa=use_sdpa)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable biases are used in matrix c and matrix d
@@ -236,21 +243,33 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)

# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor

if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, float("-inf"))
if self.use_sdpa:
scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor

out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, float("-inf"))  # mask is True at padded positions, matching forward_attention

out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2)
out = self.forward_attention(v, scores, mask)

if cache is None:
return out
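A standalone sketch, not NeMo code, of the trick used in the relative-position branch above: the Transformer-XL positional term matrix_bd (which, in the diff, also receives -inf at padded positions) is pre-scaled and passed to SDPA as an additive float attn_mask, reproducing softmax((q_with_bias_u @ k.T + matrix_bd) / sqrt(d_k)) @ v. The explicit scale keyword is only available in newer PyTorch releases; values below are random stand-ins.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, d = 1, 2, 4, 8
q_u, k, v = (torch.randn(b, h, t, d) for _ in range(3))
matrix_bd = torch.randn(b, h, t, t)  # stand-in for the shifted positional term
scale = 1 / math.sqrt(d)

# Reference: Transformer-XL style scores, content and positional terms under one softmax.
scores = (torch.matmul(q_u, k.transpose(-2, -1)) + matrix_bd) * scale
ref = torch.matmul(torch.softmax(scores, dim=-1), v)

# SDPA equivalent: pre-scale matrix_bd and pass it as an additive (float) attn_mask.
out = F.scaled_dot_product_attention(
    q_u, k, v, attn_mask=matrix_bd * scale, dropout_p=0.0, scale=scale
)

print(torch.allclose(out, ref, atol=1e-5))  # expected: True
```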
210 changes: 0 additions & 210 deletions sdpa_testing/old_multi_head_attention.py

This file was deleted.
