[AutoParallel]:fix baichuan d2s fail (#9478)
blacksheep-Aristotle authored Nov 26, 2024
1 parent 131888e commit 6141f80
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
@@ -21,6 +21,7 @@
 from functools import partial
 from typing import Optional, Tuple
 
+import numpy as np
 import paddle
 import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
@@ -100,14 +101,14 @@ def swiglu(x, y=None):
 
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
-        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        start = 2 ** (-(2 ** -(np.log2(n) - 3)))
         ratio = start
         return [start * ratio**i for i in range(n)]
 
-    if math.log2(n).is_integer():
+    if np.log2(n).is_integer():
         return _get_interleave_power_of_2(n)
     else:
-        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        closest_power_of_2 = int(2 ** np.floor(np.log2(n)))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
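For reference, here is the patched helper as a standalone, runnable sketch. The function body follows the diff above; the import, comments, and the __main__ check (with its expected output) are illustrative additions. It produces the per-head geometric slope list used for ALiBi-style attention, now computed with NumPy scalars instead of the math module, in line with the commit title's Baichuan dynamic-to-static (d2s) fix.

import numpy as np


def _get_interleave(n):
    # Returns n geometric slopes (one per attention head), as used for ALiBi.
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(np.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if np.log2(n).is_integer():
        # n is a power of two: a single geometric series.
        return _get_interleave_power_of_2(n)
    else:
        # Otherwise interleave the series for the closest power of two
        # with every other entry of the series for twice that size.
        closest_power_of_2 = int(2 ** np.floor(np.log2(n)))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )


if __name__ == "__main__":
    slopes = _get_interleave(8)
    print(len(slopes), slopes[0])  # 8 0.5  (8 heads, first slope 1/2)
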
2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/modeling_auto.py
@@ -1003,7 +1003,7 @@ def forward(
             alibi = dist.shard_tensor(alibi, global_mesh, alibi_place)
         else:
             alibi = None
-        if self.config.use_flash_attention:
+        if self.config.use_flash_attention and not self.config.alibi:
             # attention_mask in flash_attn is always None for pretrain
             # attention_mask is used in scaled_dot_product_attention with alibi_tensor
             attention_mask = None
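The one-line change above keeps the attention mask alive when ALiBi is enabled: flash attention for pretraining normally runs with attention_mask=None, but, as the in-code comments note, the mask is still needed by scaled_dot_product_attention together with the alibi tensor. A minimal, self-contained sketch of that guard follows; SimpleConfig and choose_attention_mask are hypothetical names used only for illustration, not PaddleNLP APIs.

from dataclasses import dataclass


@dataclass
class SimpleConfig:
    # Hypothetical stand-in for the relevant model config flags.
    use_flash_attention: bool = False
    alibi: bool = False


def choose_attention_mask(config: SimpleConfig, attention_mask):
    # Flash attention for pretraining runs with attention_mask=None,
    # but when ALiBi is enabled the mask is still consumed by
    # scaled_dot_product_attention with the alibi tensor, so keep it.
    if config.use_flash_attention and not config.alibi:
        return None
    return attention_mask


if __name__ == "__main__":
    mask = object()  # stand-in for a real attention mask tensor
    print(choose_attention_mask(SimpleConfig(True, False), mask))          # None
    print(choose_attention_mask(SimpleConfig(True, True), mask) is mask)   # True
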
