diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 7611fd961ab6..23518858085a 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -21,6 +21,7 @@
 from functools import partial
 from typing import Optional, Tuple
 
+import numpy as np
 import paddle
 import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
@@ -100,14 +101,14 @@ def swiglu(x, y=None):
 
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
-        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        start = 2 ** (-(2 ** -(np.log2(n) - 3)))
         ratio = start
         return [start * ratio**i for i in range(n)]
 
-    if math.log2(n).is_integer():
+    if np.log2(n).is_integer():
         return _get_interleave_power_of_2(n)
     else:
-        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        closest_power_of_2 = int(2 ** np.floor(np.log2(n)))
         return (
             _get_interleave_power_of_2(closest_power_of_2)
             + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index f33c088ceed0..de2ae508a397 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -1003,7 +1003,7 @@ def forward(
             alibi = dist.shard_tensor(alibi, global_mesh, alibi_place)
         else:
             alibi = None
-        if self.config.use_flash_attention:
+        if self.config.use_flash_attention and not self.config.alibi:
             # attention_mask in flash_attn is always None for pretrain
             # atttenton_mask is used in scaled_dot_product_attention with alibi_tensor
             attention_mask = None
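
For reference, here is a minimal standalone sketch of the patched _get_interleave helper (the ALiBi slope schedule touched by the first two hunks), assuming only numpy. It mirrors the post-patch code above; the prints at the end are illustrative usage and are not part of the patch.

# Standalone sketch mirroring the patched helper above; assumes numpy only.
# np.float64 subclasses Python float, so .is_integer() still works.
import numpy as np

def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        # Geometric slope schedule for a power-of-two number of heads.
        start = 2 ** (-(2 ** -(np.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if np.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        # The int() cast keeps the value usable as a recursion argument
        # and slice bound (np.floor returns a float).
        closest_power_of_2 = int(2 ** np.floor(np.log2(n)))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

print(_get_interleave(8))   # 8 heads: slopes 1/2, 1/4, ..., 1/256
print(_get_interleave(12))  # non-power-of-two head count: 12 slopes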