Skip to content

Commit

Permalink
[transformer] fix sdpa u2pp training nan
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Mar 19, 2024
1 parent d27e61d commit 9aaaab9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
3 changes: 1 addition & 2 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,7 @@ def forward(
assert mask.dtype != torch.bool
mask = mask.unsqueeze(1)
# matrix_bd as a mask bias
mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
matrix_bd / math.sqrt(self.d_k))
mask = (matrix_bd + mask) / math.sqrt(self.d_k)
output = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u,
k,
Expand Down
18 changes: 3 additions & 15 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,21 +310,9 @@ def log_add(*args) -> float:
return a_max + lsp


def get_dtype_min(
dtype: torch.dtype,
eps16: float = torch.finfo(torch.float16).min,
eps32: float = torch.finfo(torch.float32).min,
eps64: float = torch.finfo(torch.float64).min,
epsbf16: float = torch.finfo(torch.bfloat16).min,
):
if dtype == torch.float16:
return eps16
elif dtype == torch.float32:
return eps32
elif dtype == torch.float64:
return eps64
elif dtype == torch.bfloat16:
return epsbf16
def get_dtype_min(dtype: torch.dtype, ):
if dtype in [torch.float32, torch.bfloat16, torch.float16]:
return -1e+10
else:
raise RuntimeError(f"expected x to be floating-point, got {dtype}")

Expand Down

0 comments on commit 9aaaab9

Please sign in to comment.