diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 3c42d5bf1213..b78f58284e24 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -969,7 +969,7 @@ def forward(
         )

         # embed positions
-        if attention_mask is None:
+        if not self.config.use_flash_attention and attention_mask is None:
             # [bs, seq_len]
             attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)

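For context, a minimal standalone sketch of the gated mask construction this hunk introduces. The helper name `prepare_attention_mask` is hypothetical (the real code lives inline in `forward`), and the rationale in the comments, that the flash-attention path never consumes this dense boolean mask, is an assumption inferred from the change, not stated in the PR:

import paddle


def prepare_attention_mask(attention_mask, batch_size, seq_length_with_past, use_flash_attention):
    """Hypothetical helper mirroring the changed condition in modeling_auto.py."""
    # Only materialize the default all-ones mask on the non-flash path;
    # presumably the flash-attention kernel applies masking internally and
    # never reads this [bs, seq_len] boolean tensor.
    if not use_flash_attention and attention_mask is None:
        # [bs, seq_len], as in the original code path
        attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
    return attention_mask  # stays None when flash attention is enabled

One practical effect of the change: when `use_flash_attention` is set, the model skips allocating a `[batch_size, seq_length_with_past]` tensor on every forward call for a mask the flash path would not use.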