Commit

[AutoParallel] unify llama model (#8127)
* [AutoParallel] unify llama model

* attention_mask can be None
deepllz authored Mar 15, 2024
1 parent c406d90 commit d83bd5e
Showing 1 changed file with 70 additions and 48 deletions.
118 changes: 70 additions & 48 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -82,13 +82,26 @@ def swiglu(x, y=None):
]


+ def is_pp_enable():
+ mesh = fleet.auto.get_mesh()
+ return "pp" in mesh.dim_names
+
+
def get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp", pp_idx)
return mesh


+ def global_mesh_starts_with_pp():
+ mesh = fleet.auto.get_mesh()
+ if is_pp_enable():
+ return mesh.get_mesh_with_dim("pp")
+ else:
+ return mesh
+
+
def scaled_dot_product_attention(
query_states,
config,
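
These helpers underpin the rest of the change: `get_mesh(pp_idx)` returns the sub-mesh of one pipeline stage when a "pp" dimension exists and otherwise falls back to the global mesh, while `global_mesh_starts_with_pp()` (per its name) returns the mesh with the "pp" dimension leading, for tensors that every stage must see. Below is a minimal, framework-free sketch of that fallback logic; the `_FakeMesh` stand-in is hypothetical (not part of Paddle or this commit) and `get_mesh` is rewritten to take the mesh explicitly so it can run without a distributed launch.

    # Hypothetical stand-in for the object returned by fleet.auto.get_mesh();
    # it mimics only the two attributes the helpers above rely on.
    class _FakeMesh:
        def __init__(self, dim_names, pp_sub_meshes=None):
            self.dim_names = dim_names
            self._pp_sub_meshes = pp_sub_meshes or []

        def get_mesh_with_dim(self, name, idx=None):
            assert name == "pp"
            # Without an index: the mesh with "pp" leading (modelled here as the
            # list of stage meshes); with an index: one stage's sub-mesh.
            return self._pp_sub_meshes if idx is None else self._pp_sub_meshes[idx]


    def get_mesh(mesh, pp_idx=0):
        # Same fallback as in the commit, with the mesh passed in explicitly.
        if "pp" in mesh.dim_names:
            return mesh.get_mesh_with_dim("pp", pp_idx)
        return mesh


    pp_mesh = _FakeMesh(["pp", "dp", "mp"], ["stage0_mesh", "stage1_mesh"])
    no_pp_mesh = _FakeMesh(["dp", "mp"])

    print(get_mesh(pp_mesh, 1))     # stage1_mesh
    print(get_mesh(no_pp_mesh, 1))  # the global mesh itself (no "pp" dim)
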
@@ -800,21 +813,25 @@ def __init__(self, config: LlamaConfig):
[dist.Replicate(), dist.Shard(1)],
)

- def get_layer_ipp(layer_index):
+ def get_layer_pp_info(layer_index):
mesh = fleet.auto.get_mesh()
- if "pp" not in mesh.dim_names:
- return None
+ if is_pp_enable() is False:
+ return None, False
+ else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
- return layer_index // layer_per_stage
+ input_need_reshard = layer_index % layer_per_stage == 0
+ return layer_index // layer_per_stage, input_need_reshard

- self.layers = nn.LayerList(
- [
- LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i))
- for i in range(config.num_hidden_layers)
- ]
- )
+ decoder_layers = []
+ self.next_pp_stage_indexes = []
+ for i in range(config.num_hidden_layers):
+ pp_stage_id, input_need_reshard = get_layer_pp_info(i)
+ decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
+ if input_need_reshard:
+ self.next_pp_stage_indexes.append(i)

+ self.layers = nn.LayerList(decoder_layers)
self.norm = LlamaRMSNormAuto(config)

self.gradient_checkpointing = False
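
The new `get_layer_pp_info` replaces `get_layer_ipp`: besides the stage index it reports whether a layer is the first one on its stage, and those indexes are collected in `self.next_pp_stage_indexes` so the forward pass knows where activations must be resharded. The assignment itself is plain arithmetic and can be checked without Paddle; the sketch below uses illustrative values only (8 hidden layers, pipeline degree 4).

    import math

    def get_layer_pp_info(layer_index, num_hidden_layers=8, pp_degree=4):
        # Mirrors the logic above for the pipeline-parallel case.
        layer_per_stage = math.ceil(num_hidden_layers / pp_degree)
        input_need_reshard = layer_index % layer_per_stage == 0
        return layer_index // layer_per_stage, input_need_reshard

    stages, next_pp_stage_indexes = [], []
    for i in range(8):
        stage, need_reshard = get_layer_pp_info(i)
        stages.append(stage)
        if need_reshard:
            next_pp_stage_indexes.append(i)

    print(stages)                 # [0, 0, 1, 1, 2, 2, 3, 3]
    print(next_pp_stage_indexes)  # [0, 2, 4, 6]
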
@@ -840,13 +857,6 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length
)
- # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel
- combined_attention_mask = dist.shard_tensor(
- combined_attention_mask,
- get_mesh(),
- [dist.Replicate(), dist.Replicate()],
- )
-
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
@@ -903,6 +913,20 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

+ if self.config.sequence_parallel:
+ # [B, S, H] -> [S, B, H]
+ inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
+ global_mesh = global_mesh_starts_with_pp()
+ if position_ids is None:
+ position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
+
+ position_ids = dist.shard_tensor(
+ position_ids,
+ global_mesh,
+ [dist.Replicate() for _ in range(len(global_mesh._shape))],
+ )
+
# embed positions
if attention_mask is None:
# [bs, seq_len]
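
Two of the lines added above are ordinary shape manipulations: with sequence parallelism the embeddings are transposed from `[B, S, H]` to `[S, B, H]`, and the default `position_ids` is an arange broadcast over the batch before being replicated across every dimension of the global mesh. A NumPy illustration of just the shapes (sharding omitted; the sizes are arbitrary):

    import numpy as np

    B, S, H = 2, 5, 8
    inputs_embeds = np.zeros((B, S, H))

    seq_major = np.transpose(inputs_embeds, (1, 0, 2))  # [B, S, H] -> [S, B, H]
    position_ids = np.broadcast_to(np.arange(S, dtype=np.int64), (B, S))

    print(seq_major.shape)   # (5, 2, 8)
    print(position_ids[1])   # [0 1 2 3 4], identical for every batch row
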
@@ -914,22 +938,18 @@
else:
alibi = None

- if position_ids is None:
- position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
- # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
- position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])
-
- if self.config.sequence_parallel:
- # [B, S, H] -> [S, B, H]
- inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
-
if self.config.use_flash_attention:
# attention_mask in flash_attn is always None for pretrain
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
+ attention_mask = dist.shard_tensor(
+ attention_mask,
+ global_mesh,
+ [dist.Replicate() for _ in range(len(global_mesh._shape))],
+ )

hidden_states = inputs_embeds
hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
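
When flash attention is disabled, `_prepare_decoder_attention_mask` expands the `[bs, seq_len]` padding mask into the `[bs, 1, seq_len, seq_len]` causal mask, which is now replicated over the global mesh here instead of inside the helper. A rough NumPy sketch of that expansion, ignoring the past-key-value offset, 3-D masks, and dtype handling that the real helper also covers:

    import numpy as np

    def make_causal_mask(padding_mask):
        # padding_mask: [bs, seq_len], 1 for real tokens, 0 for padding.
        bs, seq_len = padding_mask.shape
        causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))   # lower-triangular
        expanded = padding_mask.astype(bool)[:, None, None, :]      # [bs, 1, 1, seq_len]
        return causal[None, None, :, :] & expanded                  # [bs, 1, seq_len, seq_len]

    mask = make_causal_mask(np.array([[1, 1, 1, 0]]))
    print(mask.shape)     # (1, 1, 4, 4)
    print(mask[0, 0, 1])  # [ True  True False False] -> token 1 attends to tokens 0-1 only
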
@@ -939,33 +959,37 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

- pre_ipp = None
for idx, (decoder_layer) in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient

- if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp:
- hidden_states = dist.reshard(
- hidden_states,
- get_mesh(decoder_layer.ipp),
- self.placements,
- )
- position_ids = dist.reshard(
+ ipp = decoder_layer.ipp
+ if not is_pp_enable():
+ position_ids_input = position_ids
+ attention_mask_input = attention_mask
+ else:
+ position_ids_input = dist.reshard(
position_ids,
- get_mesh(decoder_layer.ipp),
- [dist.Shard(0), dist.Replicate()],
+ get_mesh(ipp),
+ [dist.Replicate(), dist.Replicate()],
)
- attention_mask = (
+ attention_mask_input = (
dist.reshard(
attention_mask,
- get_mesh(decoder_layer.ipp),
- [dist.Shard(0), dist.Replicate()],
+ get_mesh(ipp),
+ [dist.Replicate(), dist.Replicate()],
)
if attention_mask is not None
- else attention_mask
+ else None
)

+ if idx in self.next_pp_stage_indexes:
+ hidden_states = dist.reshard(
+ hidden_states,
+ get_mesh(ipp),
+ self.placements,
+ )
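
The rewritten loop treats the two kinds of inputs differently: `position_ids` and `attention_mask` are resharded onto the current stage's mesh at every layer (or passed through untouched when pipeline parallelism is off), while `hidden_states` only moves at the boundaries recorded in `next_pp_stage_indexes`. A schematic, framework-free sketch of that dispatch, with `dist.reshard` replaced by a string-building stub and an illustrative four-layer, two-stage configuration:

    def reshard(name, stage):
        # Stub standing in for dist.reshard(tensor, get_mesh(ipp), placements).
        return f"{name}@stage{stage}"

    layer_stages = [0, 0, 1, 1]     # decoder_layer.ipp per layer (illustrative)
    next_pp_stage_indexes = [0, 2]  # first layer of each stage, as computed in __init__
    hidden_states, position_ids, attention_mask = "hidden", "pos", "mask"

    for idx, ipp in enumerate(layer_stages):
        # position_ids / attention_mask follow the current stage on every layer.
        position_ids_input = reshard(position_ids, ipp)
        attention_mask_input = reshard(attention_mask, ipp) if attention_mask is not None else None
        # The activation itself is only moved when a new pipeline stage begins.
        if idx in next_pp_stage_indexes:
            hidden_states = reshard(hidden_states, ipp)
        print(idx, hidden_states, position_ids_input, attention_mask_input)
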

if (
Expand All @@ -977,8 +1001,8 @@ def forward(
layer_outputs = recompute(
decoder_layer,
hidden_states,
- position_ids,
- attention_mask,
+ position_ids_input,
+ attention_mask_input,
output_attentions,
past_key_value,
use_cache,
@@ -987,16 +1011,14 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- position_ids,
- attention_mask,
+ position_ids_input,
+ attention_mask_input,
output_attentions,
past_key_value,
use_cache,
alibi=alibi,
)

- pre_ipp = decoder_layer.ipp
-
if type(layer_outputs) is tuple:
hidden_states = layer_outputs[0]
else:
