
Commit

handle nans after softmax in a different way
Signed-off-by: WoodieDudy <goshagks@gmail.com>
WoodieDudy committed Aug 26, 2024
1 parent bfdf629 commit 18559da
Showing 1 changed file with 24 additions and 30 deletions.
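For context, this change concerns query positions whose attention row is fully masked (for example pure padding): with a boolean attn_mask, torch.nn.functional.scaled_dot_product_attention fills such rows with -inf before the softmax, which yields NaNs. The sketch below is not part of the commit; it is a minimal, hypothetical reproduction (toy shapes, PyTorch 2.x assumed) of the failure and of the row-zeroing workaround that the diff below switches to.

# Minimal sketch (not from the commit) of the NaN issue and the workaround;
# shapes and values are illustrative only.
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)  # (batch, heads, time, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Boolean mask: True = may attend, False = masked. Query row 0 is fully masked,
# e.g. a padding-only position.
attn_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
attn_mask[..., 0, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# NaN on builds that predate the fix tracked in pytorch/pytorch#131863:
print(out[..., 0, :].isnan().any())

# The workaround adopted by this commit: zero the fully masked rows afterwards.
all_masked_rows = torch.all(~attn_mask, dim=-1, keepdim=True)
out = out.masked_fill(all_masked_rows, 0.0)
print(out.isnan().any())  # False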
54 changes: 24 additions & 30 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -48,6 +48,8 @@
'PositionalEncoding',
]

inf_val = 10000.0


class MultiHeadAttention(nn.Module):
"""Multi-Head Attention layer of Transformer.
@@ -111,7 +113,7 @@ def forward_attention(self, value, scores, mask):
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -10000.0)
scores = scores.masked_fill(mask, -inf_val)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
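A small sketch of the pre/post-softmax masking pattern used in forward_attention above (the mask layout is assumed: True marks padded positions). Because -inf_val is a large finite number rather than -inf, a fully masked row softmaxes to uniform weights instead of NaN, and the second masked_fill then zeroes the padded positions.

# Sketch only; a 3x3 toy example with the last query row fully padded.
import torch

inf_val = 10000.0
scores = torch.randn(1, 1, 3, 3)  # (batch, head, time1, time2)
mask = torch.tensor([[[[False, False, True],
                       [False, False, True],
                       [True,  True,  True]]]])

scores = scores.masked_fill(mask, -inf_val)
attn = torch.softmax(scores, dim=-1)   # fully masked row -> uniform weights, not NaN
attn = attn.masked_fill(mask, 0.0)     # padded positions contribute nothing
print(attn[0, 0, 2])                   # tensor([0., 0., 0.])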
@@ -148,22 +150,17 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1)
# add an extra column to the mask to handle the NaN problem after softmax
rows_all_false = torch.all(mask, dim=-1)
modified_tensor = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
mask = torch.cat([modified_tensor, new_column.unsqueeze(-1)], dim=-1).to(mask.device)
mask = ~mask.unsqueeze(1)

dropout_rate = self.dropout_rate if self.training else 0

# add an extra column to key and value to handle the NaN problem after softmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)

# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
all_masked_rows = torch.all(~mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
out = out.masked_fill(all_masked_rows, 0.0)

out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
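One detail worth noting in the SDPA path above is the dropout_rate gating on self.training a few lines up: the functional scaled_dot_product_attention has no notion of train/eval mode and applies dropout whenever dropout_p > 0. A brief sketch of that behavior (toy tensors, my own example, not from the commit):

# Sketch: F.scaled_dot_product_attention applies dropout unconditionally when
# dropout_p > 0, so callers must pass 0 at inference time themselves.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)
a = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
b = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(torch.allclose(a, b))  # almost surely False: dropout was applied to `a`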
@@ -291,23 +288,20 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):

if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -10000.0)
# add an extra column to the mask (matrix_bd) to handle the NaN problem after softmax
rows_all_false = torch.all(mask, dim=-1)
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
new_column = new_column.repeat(1, self.h, 1)
matrix_bd = torch.cat([matrix_bd, new_column.unsqueeze(-1)], dim=-1).to(matrix_bd.device)
matrix_bd.masked_fill_(mask, -inf_val)

dropout_rate = self.dropout_rate if self.training else 0
# add an extra column to key and value to handle the NaN problem after softmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)

# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
all_masked_rows = torch.all(mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
out = out.masked_fill(all_masked_rows, 0.0)

out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
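In this relative-position branch the padding penalty is folded into matrix_bd, so SDPA receives a float attn_mask rather than a boolean one; a float mask is added to the scaled q @ k^T scores before the softmax. A small sketch of that additive-mask path (toy shapes, my own example, not from the commit):

# Sketch of passing an additive float bias as attn_mask, as matrix_bd is above.
import torch
import torch.nn.functional as F

inf_val = 10000.0
q = torch.randn(1, 2, 3, 8)
k = torch.randn(1, 2, 3, 8)
v = torch.randn(1, 2, 3, 8)

bias = torch.zeros(1, 2, 3, 3)   # (batch, heads, time1, time2), e.g. a position bias
bias[..., -1] = -inf_val         # forbid attending to the last (padded) key
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)                 # torch.Size([1, 2, 3, 8])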
@@ -444,14 +438,14 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
# (batch, head, time, 2w + 1)

# mask invalid positions
scores[:, :, :, :start_pos] = -10000.0
scores[:, :, :, end_pos + 1 :] = -10000.0
scores[:, :, :, :start_pos] = -inf_val
scores[:, :, :, end_pos + 1 :] = -inf_val

# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
# from (bsz x seq_len) to (bsz x num_heads x seqlen x hidden_size)
mask = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
# cast to float/half then replace 1's with -inf
float_mask = mask.type_as(scores).masked_fill(mask, -10000.0)
float_mask = mask.type_as(scores).masked_fill(mask, -inf_val)
ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
# diagonal mask with zeros everywhere and -inf in place of padding
d_mask = self.sliding_chunks_matmul_qk(ones, float_mask, w, padding_value=0.0)
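The float_mask construction above converts a boolean padding mask into an additive mask in the dtype of the scores; -inf_val is a large finite value, so it stays representable in half precision and keeps fully padded rows from producing NaN after the softmax. A brief sketch of that conversion (assumed shapes, my own example):

# Sketch of the bool -> additive float mask conversion used above, assuming a
# padding mask with True at padded frames and half-precision scores.
import torch

inf_val = 10000.0
scores = torch.randn(2, 1, 5, 1, dtype=torch.float16)  # (bsz, num_heads, seq_len, hidden)
mask = torch.zeros(2, 5, dtype=torch.bool)
mask[0, 3:] = True                                     # last two frames of sample 0 are padding

mask = mask.unsqueeze(dim=1).unsqueeze(dim=-1)         # (bsz, 1, seq_len, 1)
float_mask = mask.type_as(scores).masked_fill(mask, -inf_val)
print(float_mask.dtype, float_mask.squeeze())          # float16, zeros with -10000 at padding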
@@ -984,7 +978,7 @@ def create_pe(self, positions, dtype):
pe = torch.zeros(pos_length, self.d_model, device=positions.device)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32, device=positions.device)
* -(math.log(10000.0) / self.d_model)
* -(math.log(inf_val) / self.d_model)
)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
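In create_pe the 10000.0 constant is the base of the standard sinusoidal positional encoding (Vaswani et al., 2017). A standalone sketch of the same computation, with hypothetical sizes:

# Standalone sketch of the sinusoidal table built in create_pe above;
# d_model and pos_length are illustrative values.
import math
import torch

d_model, pos_length = 8, 16
positions = torch.arange(pos_length, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
)
pe = torch.zeros(pos_length, d_model)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
print(pe.shape)  # torch.Size([16, 8])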
