Skip to content

Commit

Permalink
Merge pull request #14559 from Nuullll/ipex-sdpa-fix
Browse files Browse the repository at this point in the history
[IPEX] Fix SDPA attn_mask dtype
  • Loading branch information
AUTOMATIC1111 authored Jan 6, 2024
2 parents 8b6848c + 16b4d2c commit b00b429
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions modules/xpu_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(query.dtype)

N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length
Expand Down

0 comments on commit b00b429

Please sign in to comment.