[XPU] plain softmax_mask_fuse_upper_triangle implement for flash atte…
houj04 authored Oct 25, 2023
1 parent 84e8b05 commit 1c1f1d2
Showing 1 changed file with 22 additions and 5 deletions.
python/paddle/nn/functional/flash_attention.py
@@ -45,6 +45,15 @@ def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
         g_enable_mem_efficient = original_enable_mem_efficient
 
 
+# special for XPU device
+def get_triangle_upper_mask(x):
+    mask = paddle.full_like(x, -1e4)
+    mask.stop_gradient = True
+    mask = paddle.triu(mask, diagonal=1)
+    mask.stop_gradient = True
+    return mask
+
+
 def _math_attention(
     query,
     key,
@@ -65,11 +74,19 @@ def _math_attention(
     product = paddle.matmul(
         x=query * (head_dim**-0.5), y=key, transpose_y=True
     )
-    weights = (
-        paddle.incubate.softmax_mask_fuse_upper_triangle(product)
-        if causal
-        else F.softmax(product)
-    )
+
+    if not causal:
+        weights = F.softmax(product)
+    else:
+        # special for XPU device
+        place = paddle.get_device()
+        if "xpu" in place:
+            # softmax_mask_fuse_upper_triangle is not supported on XPU, use plain implementation
+            mask = get_triangle_upper_mask(product)
+            product = product + mask
+            weights = F.softmax(product)
+        else:
+            weights = paddle.incubate.softmax_mask_fuse_upper_triangle(product)
     if dropout_rate > 0.0:
         weights = F.dropout(
            weights, dropout_rate, training=training, mode="upscale_in_train"
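
For context, the non-XPU branch keeps the single fused kernel, while the XPU branch builds the causal mask explicitly and adds it before a plain softmax. Below is a minimal sketch of why the two paths should agree (up to the finite -1e4 value standing in for -inf in the mask). The tensor shape and the random scores are hypothetical, chosen only for illustration; the fused call is shown commented out since it is only available where that kernel is supported (e.g. CUDA devices).

    import paddle
    import paddle.nn.functional as F

    # Hypothetical attention scores of shape (batch, num_heads, seq_len, seq_len),
    # standing in for the q @ k^T product computed inside _math_attention.
    product = paddle.randn([2, 4, 8, 8])

    # Plain path (the XPU branch above): put a large negative value on the
    # strict upper triangle so those positions become ~0 after softmax.
    mask = paddle.full_like(product, -1e4)
    mask = paddle.triu(mask, diagonal=1)
    mask.stop_gradient = True
    plain_weights = F.softmax(product + mask)

    # Fused path (the non-XPU branch), a single kernel doing mask + softmax:
    # fused_weights = paddle.incubate.softmax_mask_fuse_upper_triangle(product)

Note that `stop_gradient = True` keeps the constant mask out of the backward graph; only the scores themselves receive gradients, matching the behavior of the fused kernel.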
