
Commit

[AutoParallel] unify llama model
deepllz committed Mar 14, 2024
1 parent c406d90 commit 01d04b2
Showing 2 changed files with 76 additions and 54 deletions.
8 changes: 4 additions & 4 deletions llm/llama/auto_parallel/run_pretrain_auto.sh
@@ -1,11 +1,11 @@
 # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -29,7 +29,7 @@ export PYTHONPATH=../../../:$PYTHONPATH
 # export FLAGS_call_stack_level=3
 # export FLAGS_use_cuda_managed_memory=true
 
-# export FLAGS_embedding_deterministic=1
+# export FLAGS_embedding_deterministic=1
 # export FLAGS_cudnn_deterministic=1
 # export NVIDIA_TF32_OVERRIDE=0
 
122 changes: 72 additions & 50 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -82,13 +82,26 @@ def swiglu(x, y=None):
 ]
 
 
+def is_pp_enable():
+    mesh = fleet.auto.get_mesh()
+    return "pp" in mesh.dim_names
+
+
 def get_mesh(pp_idx=0):
     mesh = fleet.auto.get_mesh()
     if "pp" in mesh.dim_names:
         mesh = mesh.get_mesh_with_dim("pp", pp_idx)
     return mesh
 
 
+def global_mesh_starts_with_pp():
+    mesh = fleet.auto.get_mesh()
+    if is_pp_enable():
+        return mesh.get_mesh_with_dim("pp")
+    else:
+        return mesh
+
+
 def scaled_dot_product_attention(
     query_states,
     config,
@@ -800,21 +813,29 @@ def __init__(self, config: LlamaConfig):
             [dist.Replicate(), dist.Shard(1)],
         )
 
-        def get_layer_ipp(layer_index):
+        def get_layer_pp_info(layer_index):
             mesh = fleet.auto.get_mesh()
-            if "pp" not in mesh.dim_names:
-                return None
+            if is_pp_enable() is False:
+                return None, False
             else:
                 pp_degree = mesh.get_dim_size("pp")
-                layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
-                return layer_index // layer_per_stage
+                layer_per_stage = math.ceil(
+                    config.num_hidden_layers / pp_degree
+                )
+                input_need_reshard = layer_index % layer_per_stage == 0
+                return layer_index // layer_per_stage, input_need_reshard
+
+        decoder_layers = []
+        self.next_pp_stage_indexes = []
+        for i in range(config.num_hidden_layers):
+            pp_stage_id, input_need_reshard = get_layer_pp_info(i)
+            decoder_layers.append(
+                LlamaDecoderLayerAuto(config, False, pp_stage_id)
+            )
+            if input_need_reshard:
+                self.next_pp_stage_indexes.append(i)
 
-        self.layers = nn.LayerList(
-            [
-                LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i))
-                for i in range(config.num_hidden_layers)
-            ]
-        )
+        self.layers = nn.LayerList(decoder_layers)
         self.norm = LlamaRMSNormAuto(config)
 
         self.gradient_checkpointing = False
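
As a quick illustration of the stage-assignment arithmetic that get_layer_pp_info introduces above, the following standalone sketch (not part of the commit; the layer count and pipeline degree are made-up values) shows which pipeline stage each decoder layer lands on and which layer indexes end up in next_pp_stage_indexes:

    import math

    num_hidden_layers = 8  # assumed model depth, for illustration only
    pp_degree = 4          # assumed pipeline-parallel degree, for illustration only

    layer_per_stage = math.ceil(num_hidden_layers / pp_degree)
    next_pp_stage_indexes = []
    for layer_index in range(num_hidden_layers):
        pp_stage_id = layer_index // layer_per_stage
        input_need_reshard = layer_index % layer_per_stage == 0
        if input_need_reshard:
            next_pp_stage_indexes.append(layer_index)
        print(layer_index, pp_stage_id, input_need_reshard)
    # Layers 0-1 map to stage 0, 2-3 to stage 1, 4-5 to stage 2, 6-7 to stage 3;
    # next_pp_stage_indexes becomes [0, 2, 4, 6], i.e. the first layer of each stage.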
@@ -840,13 +861,6 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel
-                    combined_attention_mask = dist.shard_tensor(
-                        combined_attention_mask,
-                        get_mesh(),
-                        [dist.Replicate(), dist.Replicate()],
-                    )
-
                     expanded_attn_mask = expanded_attn_mask & combined_attention_mask
             # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
             elif len(attention_mask.shape) == 3:
@@ -903,6 +917,20 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
+        global_mesh = global_mesh_starts_with_pp()
+        if position_ids is None:
+            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
+
+        position_ids = dist.shard_tensor(
+            position_ids,
+            global_mesh,
+            [dist.Replicate() for _ in range(len(global_mesh._shape))],
+        )
+
         # embed positions
         if attention_mask is None:
             # [bs, seq_len]
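
The sequence-parallel branch added above only moves the embeddings from batch-major to sequence-major layout; here is a tiny standalone check (toy shapes, not from the commit) of what that paddle.transpose call does:

    import paddle

    inputs_embeds = paddle.randn([2, 8, 16])                    # [B, S, H] with toy sizes
    inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # -> [S, B, H]
    print(inputs_embeds.shape)                                  # [8, 2, 16]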
@@ -914,22 +942,18 @@ else:
         else:
             alibi = None
 
-        if position_ids is None:
-            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
-            # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
-            position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])
-
-        if self.config.sequence_parallel:
-            # [B, S, H] -> [S, B, H]
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
-
         if self.config.use_flash_attention:
             # attention_mask in flash_attn is always None for pretrain
             attention_mask = None
         else:
             attention_mask = self._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]
+            attention_mask = dist.shard_tensor(
+                attention_mask,
+                global_mesh,
+                [dist.Replicate() for _ in range(len(global_mesh._shape))],
+            )
 
         hidden_states = inputs_embeds
         hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
@@ -939,34 +963,34 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
 
-        pre_ipp = None
         for idx, (decoder_layer) in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             has_gradient = not hidden_states.stop_gradient
+            ipp = decoder_layer.ipp
+            if not is_pp_enable():
+                position_ids_input = position_ids
+                attention_mask_input = attention_mask
+            else:
+                position_ids_input = dist.reshard(
+                    position_ids,
+                    get_mesh(ipp),
+                    [dist.Replicate(), dist.Replicate()],
+                )
+                attention_mask_input = dist.reshard(
+                    attention_mask,
+                    get_mesh(ipp),
+                    [dist.Replicate(), dist.Replicate()],
+                )
 
-            if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp:
+            if idx in self.next_pp_stage_indexes:
                 hidden_states = dist.reshard(
                     hidden_states,
-                    get_mesh(decoder_layer.ipp),
+                    get_mesh(ipp),
                     self.placements,
                 )
-                position_ids = dist.reshard(
-                    position_ids,
-                    get_mesh(decoder_layer.ipp),
-                    [dist.Shard(0), dist.Replicate()],
-                )
-                attention_mask = (
-                    dist.reshard(
-                        attention_mask,
-                        get_mesh(decoder_layer.ipp),
-                        [dist.Shard(0), dist.Replicate()],
-                    )
-                    if attention_mask is not None
-                    else attention_mask
-                )
 
             if (
                 self.enable_recompute
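
To see the control flow the loop above now follows, here is a plain-Python sketch (dummy strings stand in for distributed tensors and a stub replaces dist.reshard; none of this is from the commit): every layer gets position_ids and attention_mask placed on its own stage's mesh, while hidden_states only moves between meshes at the first layer of each stage.

    def reshard(tensor, stage):
        # Stub for dist.reshard(tensor, get_mesh(stage), placements).
        return f"{tensor}@stage{stage}"

    pp_enabled = True
    layer_stages = [0, 0, 1, 1, 2, 2, 3, 3]   # decoder_layer.ipp per layer (assumed values)
    next_pp_stage_indexes = [0, 2, 4, 6]      # from get_layer_pp_info, as sketched earlier

    hidden_states, position_ids, attention_mask = "h", "pos", "mask"
    for idx, ipp in enumerate(layer_stages):
        if not pp_enabled:
            position_ids_input = position_ids
            attention_mask_input = attention_mask
        else:
            position_ids_input = reshard(position_ids, ipp)
            attention_mask_input = reshard(attention_mask, ipp)
        if idx in next_pp_stage_indexes:
            hidden_states = reshard(hidden_states, ipp)
        print(idx, ipp, hidden_states, position_ids_input, attention_mask_input)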
@@ -977,8 +1001,8 @@
                 layer_outputs = recompute(
                     decoder_layer,
                     hidden_states,
-                    position_ids,
-                    attention_mask,
+                    position_ids_input,
+                    attention_mask_input,
                     output_attentions,
                     past_key_value,
                     use_cache,
Expand All @@ -987,16 +1011,14 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
position_ids,
attention_mask,
position_ids_input,
attention_mask_input,
output_attentions,
past_key_value,
use_cache,
alibi=alibi,
)

pre_ipp = decoder_layer.ipp

if type(layer_outputs) is tuple:
hidden_states = layer_outputs[0]
else:

