Add static graph support for "scaled_dot_product_attention" (#59498)
* Added static graph support for 'scaled_dot_product_attention'

* Add static graph support for "scaled_dot_product_attention"
lchdl authored Nov 30, 2023
1 parent 855e51e commit 0846112
Showing 1 changed file with 56 additions and 16 deletions.
python/paddle/nn/functional/flash_attention.py (72 changes: 56 additions & 16 deletions)
@@ -498,22 +498,62 @@ def scaled_dot_product_attention(
             >>> print(output)
             >>> # doctest: -SKIP
     """
 
     if attn_mask is None:
         # downgraded to ordinary flash attention implementation
         out, _ = flash_attention(query, key, value, dropout_p, is_causal)
         return out
     else:
-        fixed_seed_offset = (None,)
-        return_softmax = False
-        rng_name = ""
-        out, _ = _C_ops.flash_attn(
-            query,
-            key,
-            value,
-            fixed_seed_offset,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            return_softmax,
-            not training,
-            rng_name,
-        )
-        return out
+        if in_dynamic_mode():
+            fixed_seed_offset = (None,)
+            return_softmax = False
+            rng_name = ""
+            out, _ = _C_ops.flash_attn(
+                query,
+                key,
+                value,
+                fixed_seed_offset,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                return_softmax,
+                not training,
+                rng_name,
+            )
+            return out
+        else:
+            helper = LayerHelper('flash_attn', **locals())
+            dtype = helper.input_dtype(input_param_name='q')
+            out = helper.create_variable_for_type_inference(dtype)
+            softmax = helper.create_variable_for_type_inference(dtype)
+            softmax_lse = helper.create_variable_for_type_inference(
+                paddle.float32
+            )
+            seed_offset = helper.create_variable_for_type_inference(
+                paddle.int64
+            )
+            inputs = {
+                'q': query,
+                'k': key,
+                'v': value,
+                'attn_mask': attn_mask,
+            }
+            outputs = {
+                'out': out,
+                'softmax': softmax,
+                'softmax_lse': softmax_lse,
+                'seed_offset': seed_offset,
+            }
+            helper.append_op(
+                type='flash_attn',
+                inputs=inputs,
+                outputs=outputs,
+                attrs={
+                    'dropout': dropout_p,
+                    'causal': is_causal,
+                    'return_softmax': False,
+                    'is_test': not training,
+                    'rng_name': '',
+                },
+            )
+            return out
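With this change, `scaled_dot_product_attention` can be called with an `attn_mask` while static-graph mode is enabled: instead of the dynamic-only `_C_ops.flash_attn` call, the new branch appends a `flash_attn` op to the program through `LayerHelper`. A minimal sketch of how the static branch would be exercised (not part of the commit; the shapes, names, and dtype below are illustrative assumptions, and actually executing the program requires a GPU build with flash-attention support):

import paddle
import paddle.nn.functional as F

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # flash attention takes [batch, seq_len, num_heads, head_dim] inputs
    q = paddle.static.data(name='q', shape=[1, 128, 8, 64], dtype='float16')
    k = paddle.static.data(name='k', shape=[1, 128, 8, 64], dtype='float16')
    v = paddle.static.data(name='v', shape=[1, 128, 8, 64], dtype='float16')
    # additive mask broadcastable to [batch, num_heads, seq_len, seq_len]
    mask = paddle.static.data(
        name='mask', shape=[1, 8, 128, 128], dtype='float16'
    )
    # attn_mask is not None and in_dynamic_mode() is False, so the
    # LayerHelper branch above builds the op at graph-construction time
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(main_prog)  # the program now contains a flash_attn op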
