use pytorch sdpa
Signed-off-by: WoodieDudy <goshagks@gmail.com>
WoodieDudy committed Aug 12, 2024
1 parent d6cfdc0 commit 741be10
Showing 4 changed files with 58 additions and 9 deletions.
@@ -145,6 +145,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: true

# Convolution module's params
conv_kernel_size: 9
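For context, here is a minimal Python sketch of enabling the new option when building the encoder directly; apart from use_pytorch_sdpa, the constructor arguments are illustrative assumptions rather than values taken from this commit.

```python
# Hypothetical usage sketch: construct a ConformerEncoder with the SDPA path enabled.
# All argument values other than use_pytorch_sdpa are example choices, not from this diff.
from nemo.collections.asr.modules import ConformerEncoder

encoder = ConformerEncoder(
    feat_in=80,              # e.g. 80-dim log-mel features
    n_layers=4,
    d_model=256,
    n_heads=4,
    use_pytorch_sdpa=True,   # flag added in this commit; False falls back to manual attention
)
```

Setting the flag to false in the YAML above (or False in code) keeps the original manual attention path.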
8 changes: 8 additions & 0 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -147,6 +147,8 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
Defaults to 1.
global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate.
Defaults to False.
use_pytorch_sdpa (bool): use PyTorch's scaled_dot_product_attention (SDPA) instead of the manual attention implementation.
Defaults to True.
"""

@@ -295,6 +297,7 @@ def __init__(
global_tokens: int = 0,
global_tokens_spacing: int = 1,
global_attn_separate: bool = False,
use_pytorch_sdpa: bool = True,
):
super().__init__()
d_ff = d_model * ff_expansion_factor
@@ -309,6 +312,7 @@ def __init__(
self.global_tokens = global_tokens
self.global_attn_separate = global_attn_separate
self.global_tokens_spacing = global_tokens_spacing
self.use_pytorch_sdpa = use_pytorch_sdpa

# Setting up the att_context_size
(
@@ -430,6 +434,7 @@ def __init__(
pos_bias_v=pos_bias_v,
att_context_size=self.att_context_size,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
self.layers.append(layer)

@@ -1028,6 +1033,7 @@ def change_attention_model(
max_cache_len=att_context_size[0],
pos_bias_u=None,
pos_bias_v=None,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'rel_pos_local_attn':
new_attn = RelPositionMultiHeadAttentionLongformer(
@@ -1038,13 +1044,15 @@
att_context_size=att_context_size,
pos_bias_u=None,
pos_bias_v=None,
# use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'abs_pos':
new_attn = MultiHeadAttention(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
max_cache_len=att_context_size[0],
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
else:
raise ValueError(
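Because change_attention_model now forwards use_pytorch_sdpa to the attention modules it rebuilds, a switched model keeps whatever SDPA setting the encoder was created with. A hedged usage sketch follows; the checkpoint name and argument values are examples, not part of this diff.

```python
# Hypothetical sketch: load a pretrained Conformer and switch its attention variant.
# change_attention_model is the method patched above; the checkpoint name is an example.
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_small")
asr_model.encoder.change_attention_model(
    self_attention_model="rel_pos",   # rebuilt as RelPositionMultiHeadAttention
    att_context_size=[-1, -1],        # unlimited left/right context
)
```

Note that the 'rel_pos_local_attn' branch above leaves the flag commented out, so Longformer-style local attention keeps the manual implementation in this commit.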
4 changes: 4 additions & 0 deletions nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -77,9 +77,11 @@ def __init__(
pos_bias_v=None,
att_context_size=[-1, -1],
use_bias=True,
use_pytorch_sdpa=True,
):
super(ConformerLayer, self).__init__()

self.use_pytorch_sdpa = use_pytorch_sdpa
self.self_attention_model = self_attention_model
self.n_heads = n_heads
self.fc_factor = 0.5
@@ -111,6 +113,7 @@ def __init__(
pos_bias_v=pos_bias_v,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'rel_pos_local_attn':
self.self_attn = RelPositionMultiHeadAttentionLongformer(
@@ -133,6 +136,7 @@
dropout_rate=dropout_att,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
else:
raise ValueError(
54 changes: 45 additions & 9 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -58,11 +58,13 @@ class MultiHeadAttention(nn.Module):
use_bias (bool): whether to apply bias in linear and conv layers
"""

def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_bias=True):
def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_bias=True, use_pytorch_sdpa=True):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadAttention, self).__init__()
self.use_pytorch_sdpa = use_pytorch_sdpa
self.cache_drop_size = None
self.use_bias = use_bias
self.dropout_rate = dropout_rate
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
@@ -141,8 +143,24 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
# temporary until we solve this more gracefully
with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
out = self.forward_attention(v, scores, mask)

if self.use_pytorch_sdpa:
scale = 1 / self.s_d_k
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1)

dropout_rate = self.dropout_rate if self.training else 0
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
out = self.forward_attention(v, scores, mask)

if cache is None:
return out
else:
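For readers comparing the two branches of forward above, the following standalone check (not NeMo code; shapes and the mask convention are chosen for illustration, and dropout is omitted) shows that scaled_dot_product_attention with an explicit scale reproduces the manual softmax(QK^T / sqrt(d_k)) V computation.

```python
# Standalone equivalence check between manual attention and torch SDPA.
# Requires a PyTorch version that accepts the `scale` keyword (2.1+).
import math
import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 16, 32
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)

# Boolean mask in SDPA's convention: True = attend, False = masked out.
attend = torch.rand(batch, 1, time, time) > 0.1
attend = attend | torch.eye(time, dtype=torch.bool)  # keep the diagonal attendable

# Manual path: scores, masking, softmax, weighted sum of values.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(~attend, float("-inf"))
manual = torch.matmul(torch.softmax(scores, dim=-1), v)

# SDPA path, with the same explicit scale as in the diff (scale = 1 / s_d_k).
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attend, scale=1 / math.sqrt(d_k))

print(torch.allclose(manual, sdpa, atol=1e-5))  # expected: True
```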
@@ -166,14 +184,17 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
use_bias (bool): whether to apply bias in linear and conv layers of MultiHeadAttention
"""

def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_bias=True):
def __init__(
self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_bias=True, use_pytorch_sdpa=True
):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(
n_head=n_head,
n_feat=n_feat,
dropout_rate=dropout_rate,
max_cache_len=max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=use_pytorch_sdpa,
)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
@@ -228,6 +249,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
q = q.transpose(1, 2) # (batch, time1, head, d_k)

n_batch_pos = pos_emb.size(0)
n_batch = value.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)

@@ -240,18 +262,32 @@
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]

scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2)
if self.use_pytorch_sdpa:
scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor

if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -10000.0)

out = self.forward_attention(v, scores, mask)
dropout_rate = self.dropout_rate if self.training else 0
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2)
out = self.forward_attention(v, scores, mask)

if cache is None:
return out
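The relative-position branch works because SDPA adds a floating-point attn_mask to the already-scaled scores before the softmax, so the shifted position term matrix_bd, pre-multiplied by the scale, can play the role of the additive bias that the manual path sums into the scores. A standalone sketch of that identity, with the padding mask and dropout omitted and all names local to the example:

```python
# Standalone sketch: a float attn_mask makes SDPA compute
# softmax(q @ k^T * scale + attn_mask) @ v, which matches the manual rel-pos scores.
import math
import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 16, 32
q_with_bias_u = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)
matrix_bd = torch.randn(batch, heads, time, time)  # stand-in for the shifted position term

scale = 1 / math.sqrt(d_k)

# Manual path (the else branch above): (matrix_ac + matrix_bd) / sqrt(d_k), then softmax.
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
manual = torch.matmul(torch.softmax((matrix_ac + matrix_bd) * scale, dim=-1), v)

# SDPA path (the new branch): pre-scale matrix_bd and pass it as the additive mask.
sdpa = F.scaled_dot_product_attention(
    q_with_bias_u, k, v, attn_mask=matrix_bd * scale, dropout_p=0.0, scale=scale
)

print(torch.allclose(manual, sdpa, atol=1e-5))  # expected: True
```

In the diff itself, the padding mask is folded into matrix_bd via masked_fill_(mask, -10000.0), which keeps a single additive tensor for the attn_mask argument.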
