Commit

update sdpa usage
Signed-off-by: kramarenko.gs <kramarenko.gs@skbkontur.ru>
WoodieDudy authored and kramarenko.gs committed Aug 21, 2024
1 parent 7fcac13 commit d0ed1ce
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -145,15 +145,26 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
q, k, v = self.forward_qkv(query, key, value)

if self.use_pytorch_sdpa:
scale = 1 / self.s_d_k
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1)
# add an extra column to the mask to avoid NaN after softmax when a row is fully masked
rows_all_false = torch.all(mask, dim=-1)
modified_tensor = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
mask = torch.cat([modified_tensor, new_column.unsqueeze(-1)], dim=-1).to(mask.device)

dropout_rate = self.dropout_rate if self.training else 0

# add an extra zero column to key and value to pair with the extra mask column (avoids NaN after softmax)
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
q, k, v, attn_mask=mask, dropout_p=dropout_rate
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
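
The extra column added in this hunk works around a property of torch.nn.functional.scaled_dot_product_attention: if every key position in a row of a boolean mask is disallowed, that row of attention logits is all -inf and the softmax returns NaN. The hunk converts the boolean mask to an additive one and appends a single always-reachable zero key/value column, so fully masked query rows attend only to it and come out as zeros. Below is a minimal standalone sketch of the same trick; tensor shapes and variable names are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, d_k = 1, 2, 3, 4, 8
q = torch.randn(batch, heads, q_len, d_k)
k = torch.randn(batch, heads, k_len, d_k)
v = torch.randn(batch, heads, k_len, d_k)

# NeMo-style mask: True marks positions that must NOT be attended to.
mask = torch.zeros(batch, 1, q_len, k_len, dtype=torch.bool)
mask[:, :, 0, :] = True  # first query row is fully masked

# Passing the boolean mask directly (True = attend in PyTorch) yields NaN rows.
naive = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)
print(naive[0, 0, 0].isnan().all())  # tensor(True)

# Workaround mirroring the hunk: additive mask plus one extra column that is
# +10000.0 only for fully masked rows, paired with zero key/value columns.
rows_all_masked = mask.all(dim=-1)
additive = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
extra_col = torch.where(rows_all_masked, torch.tensor(10000.0), torch.tensor(-10000.0))
additive = torch.cat([additive, extra_col.unsqueeze(-1)], dim=-1)

k_pad = torch.cat([k, torch.zeros(batch, heads, 1, d_k)], dim=-2)
v_pad = torch.cat([v, torch.zeros(batch, heads, 1, d_k)], dim=-2)

safe = F.scaled_dot_product_attention(q, k_pad, v_pad, attn_mask=additive)
print(safe.isnan().any())  # tensor(False): fully masked rows come out as zeros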
@@ -283,10 +294,21 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -10000.0)
# add an extra column to the mask (matrix_bd) to avoid NaN after softmax when a row is fully masked
rows_all_false = torch.all(mask, dim=-1)
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
new_column = new_column.repeat(1, self.h, 1)
matrix_bd = torch.cat([matrix_bd, new_column.unsqueeze(-1)], dim=-1).to(matrix_bd.device)

dropout_rate = self.dropout_rate if self.training else 0
# add an extra zero column to key and value to pair with the extra mask column (avoids NaN after softmax)
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
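
In this second hunk the relative-position term matrix_bd, with masking already folded in as -10000.0 fills, is passed to SDPA as a float attn_mask. Per the PyTorch documentation, a float attn_mask is added to q·kᵀ·scale before the softmax, and with the explicit scale argument removed SDPA falls back to its default of 1/sqrt(head_dim). The sketch below checks that behavior; shapes are illustrative and bias merely stands in for matrix_bd.

import math
import torch
import torch.nn.functional as F

batch, heads, seq, d_k = 1, 2, 5, 16
q = torch.randn(batch, heads, seq, d_k)
k = torch.randn(batch, heads, seq, d_k)
v = torch.randn(batch, heads, seq, d_k)
bias = torch.randn(batch, heads, seq, seq)  # stands in for matrix_bd

# No scale argument: SDPA uses its default 1 / sqrt(d_k).
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# The same computation spelled out by hand.
manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k) + bias, dim=-1) @ v
print(torch.allclose(sdpa_out, manual, atol=1e-5))  # True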
@@ -336,6 +358,7 @@ def __init__(
global_tokens_spacing=1,
global_attn_separate=False,
use_bias=True,
use_pytorch_sdpa=False,
):
"""Construct an RelPositionMultiHeadAttentionLongformer object."""
super().__init__(
@@ -346,7 +369,12 @@
pos_bias_v=pos_bias_v,
max_cache_len=max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=use_pytorch_sdpa,
)

if use_pytorch_sdpa:
raise NotImplementedError("Not implemented for Longformer yet")

self.att_context_size = att_context_size
self.global_tokens = global_tokens
self.global_tokens_spacing = global_tokens_spacing
