diff --git a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base.yaml b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base.yaml
index 836cf5e9c..020401a16 100644
--- a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base.yaml
+++ b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base.yaml
@@ -69,9 +69,11 @@ Data:
       shuffle: False
       drop_last: True
     loader:
-      num_workers: 1
+      num_workers: 0
       return_list: False
-      collate_fn: ernie_collate_data
+      collate_fn:
+        name: ErnieCollateData
+        micro_batch_size:

   Eval:
     dataset:
@@ -96,7 +98,9 @@ Data:
     loader:
       num_workers: 1
       return_list: False
-      collate_fn: ernie_collate_data
+      collate_fn:
+        name: ErnieCollateData
+        micro_batch_size: 1

 Optimizer:
   name: FusedAdamW
diff --git a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_175B_mp8_pp16.yaml b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_175B_mp8_pp16.yaml
new file mode 100644
index 000000000..286d9acb4
--- /dev/null
+++ b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_175B_mp8_pp16.yaml
@@ -0,0 +1,44 @@
+_base_: ./pretrain_ernie_base.yaml
+
+Global:
+  global_batch_size:
+  local_batch_size: 512
+  micro_batch_size: 1
+
+
+Model:
+  vocab_size: 40000
+  hidden_size: 12288
+  num_hidden_layers: 96
+  num_attention_heads: 96
+  intermediate_size:
+  hidden_act: "gelu"
+  hidden_dropout_prob: 0.1
+  attention_probs_dropout_prob: 0.1
+  max_position_embeddings: 512
+  type_vocab_size: 4
+  initializer_range: 0.02
+  pad_token_id: 0
+  task_type_vocab_size: 3
+  task_id: 0
+  use_task_id: True
+  use_recompute: True
+
+
+Data:
+  Train:
+    dataset:
+      tokenizer_type: ernie-1.0-base-zh-cw
+  Eval:
+    dataset:
+      tokenizer_type: ernie-1.0-base-zh-cw
+
+
+Distributed:
+  dp_degree: 1
+  mp_degree: 8
+  pp_degree: 16
+  sharding:
+    sharding_degree: 1
+    sharding_stage: 1
+    sharding_offload: False
diff --git a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_single_card.yaml b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_345M_single_card.yaml
similarity index 90%
rename from ppfleetx/configs/nlp/ernie/pretrain_ernie_base_single_card.yaml
rename to ppfleetx/configs/nlp/ernie/pretrain_ernie_base_345M_single_card.yaml
index 3fd696b27..445b07c86 100644
--- a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_single_card.yaml
+++ b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_345M_single_card.yaml
@@ -8,9 +8,9 @@ Global:

 Model:
   vocab_size: 40000
-  hidden_size: 768
-  num_hidden_layers: 1
-  num_attention_heads: 12
+  hidden_size: 1024
+  num_hidden_layers: 24
+  num_attention_heads: 16
   intermediate_size:
   hidden_act: "gelu"
   hidden_dropout_prob: 0.1
@@ -24,6 +24,7 @@ Model:
   use_task_id: True
   use_recompute: False

+
 Data:
   Train:
     dataset:
diff --git a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_3D.yaml b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_3D.yaml
index 64bfc2a35..f00b5636d 100644
--- a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_3D.yaml
+++ b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_3D.yaml
@@ -2,8 +2,8 @@ _base_: ./pretrain_ernie_base.yaml

 Global:
   global_batch_size:
-  local_batch_size: 4
-  micro_batch_size: 4
+  local_batch_size: 8
+  micro_batch_size: 1


 Model:
@@ -34,9 +34,9 @@ Data:


 Distributed:
-  dp_degree: 1
-  mp_degree: 8
-  pp_degree: 1
+  dp_degree: 2
+  mp_degree: 2
+  pp_degree: 2
   sharding:
     sharding_degree: 1
     sharding_stage: 1
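The batch-size edits in `pretrain_ernie_base_3D.yaml` only make sense together: with `dp_degree: 2` each data-parallel rank consumes `local_batch_size: 8`, and with `micro_batch_size: 1` the pipeline splits that local batch into 8 micro-batches per step. A minimal sketch of that arithmetic, assuming the usual PaddleFleetX convention `global_batch_size = local_batch_size * dp_degree` (the helper below is illustrative, not part of the repo):

```python
def batch_arithmetic(local_batch_size, micro_batch_size, dp_degree):
    # micro-batches must tile the local batch exactly, as ErnieCollateData assumes
    assert local_batch_size % micro_batch_size == 0
    accumulate_steps = local_batch_size // micro_batch_size
    global_batch_size = local_batch_size * dp_degree
    return global_batch_size, accumulate_steps

# pretrain_ernie_base_3D.yaml after this change: dp=2, local=8, micro=1
print(batch_arithmetic(8, 1, 2))    # (16, 8): 16 samples per step, 8 micro-batches per rank
```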
diff --git a/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_6.7B_sharding16.yaml b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_6.7B_sharding16.yaml
new file mode 100644
index 000000000..34ab852ea
--- /dev/null
+++ b/ppfleetx/configs/nlp/ernie/pretrain_ernie_base_6.7B_sharding16.yaml
@@ -0,0 +1,43 @@
+_base_: ./pretrain_ernie_base.yaml
+
+Global:
+  global_batch_size:
+  local_batch_size: 512
+  micro_batch_size: 1
+
+
+Model:
+  vocab_size: 40000
+  hidden_size: 4096
+  num_hidden_layers: 32
+  num_attention_heads: 32
+  intermediate_size:
+  hidden_act: "gelu"
+  hidden_dropout_prob: 0.1
+  attention_probs_dropout_prob: 0.1
+  max_position_embeddings: 512
+  type_vocab_size: 4
+  initializer_range: 0.02
+  pad_token_id: 0
+  task_type_vocab_size: 3
+  task_id: 0
+  use_task_id: True
+  use_recompute: True
+
+Data:
+  Train:
+    dataset:
+      tokenizer_type: ernie-1.0-base-zh-cw
+  Eval:
+    dataset:
+      tokenizer_type: ernie-1.0-base-zh-cw
+
+
+Distributed:
+  dp_degree: 1
+  mp_degree: 1
+  pp_degree: 1
+  sharding:
+    sharding_degree: 16
+    sharding_stage: 1
+    sharding_offload: False
diff --git a/ppfleetx/data/__init__.py b/ppfleetx/data/__init__.py
index d5adb5882..7c429250d 100644
--- a/ppfleetx/data/__init__.py
+++ b/ppfleetx/data/__init__.py
@@ -80,9 +80,16 @@ def build_dataloader(config, mode):
     if 'loader' in config[mode].keys():
         config_loader = config[mode].loader
         config_loader = copy.deepcopy(config_loader)
-        collate_fn_name = config_loader.pop('collate_fn', None)
-        collate_fn = getattr(
-            utils, collate_fn_name) if collate_fn_name is not None else None
+
+        collate_fn_cfg = config_loader.pop('collate_fn', None)
+        if isinstance(collate_fn_cfg, str):
+            collate_fn = getattr(
+                utils, collate_fn_cfg) if collate_fn_cfg is not None else None
+        elif isinstance(collate_fn_cfg, dict):
+            collate_fn_class_name = collate_fn_cfg.pop("name")
+            collate_fn = eval("utils.{}".format(collate_fn_class_name))(
+                **collate_fn_cfg)
+            logger.debug("build collate_fn({}) success...".format(collate_fn))

     def worker_init_fn(worker_id):
         """ set seed in subproces for dataloader when num_workers > 0"""
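`build_dataloader` now accepts both collate_fn styles: a plain string still resolves to a function in `ppfleetx.data.utils`, while a dict names a collate class that is instantiated with the remaining keys as keyword arguments, which is how the YAML `collate_fn: {name: ErnieCollateData, micro_batch_size: ...}` above gets built. A standalone sketch of that dispatch with toy stand-ins for the utils module; it uses `getattr` where the patch uses `eval`, which does the same attribute lookup without evaluating arbitrary strings:

```python
import types

def toy_collate_fn(batch):                      # stands in for a function such as gpt_collate_fn
    return batch

class ToyCollateClass:                          # stands in for a class such as ErnieCollateData
    def __init__(self, micro_batch_size=1):
        self.micro_batch_size = micro_batch_size
    def __call__(self, batch):
        return batch

utils = types.SimpleNamespace(toy_collate_fn=toy_collate_fn, ToyCollateClass=ToyCollateClass)

def resolve_collate_fn(cfg):
    if cfg is None:
        return None
    if isinstance(cfg, str):                    # collate_fn: toy_collate_fn
        return getattr(utils, cfg)
    if isinstance(cfg, dict):                   # collate_fn: {name: ..., micro_batch_size: ...}
        cfg = dict(cfg)                         # avoid mutating the loaded config
        return getattr(utils, cfg.pop("name"))(**cfg)
    raise TypeError("collate_fn must be a str, a dict or None")

print(resolve_collate_fn({"name": "ToyCollateClass", "micro_batch_size": 2}).micro_batch_size)  # 2
```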
diff --git a/ppfleetx/data/dataset/ernie/dataset_utils.py b/ppfleetx/data/dataset/ernie/dataset_utils.py
index 03431d46b..2c8a12d00 100644
--- a/ppfleetx/data/dataset/ernie/dataset_utils.py
+++ b/ppfleetx/data/dataset/ernie/dataset_utils.py
@@ -336,8 +336,10 @@ def create_masked_lm_predictions(tokens,
         return (output_tokens, masked_lm_positions, masked_lm_labels,
                 token_boundary)

-    num_to_predict = min(max_predictions_per_seq,
-                         max(1, int(round(len(tokens) * masked_lm_prob))))
+    # NOTE(shenliang03): to avoid num_to_predict < 1
+    num_to_predict = max(1,
+                         min(max_predictions_per_seq,
+                             max(1, int(round(len(tokens) * masked_lm_prob)))))

     ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
     if not geometric_dist:
diff --git a/ppfleetx/data/dataset/ernie/ernie_dataset.py b/ppfleetx/data/dataset/ernie/ernie_dataset.py
index fb298c922..370bc5329 100644
--- a/ppfleetx/data/dataset/ernie/ernie_dataset.py
+++ b/ppfleetx/data/dataset/ernie/ernie_dataset.py
@@ -284,7 +284,6 @@ def get_train_data_file(input_dir):
         if (os.path.isfile(os.path.join(input_dir, f)) and "_idx.npz" in
             str(f))
     ]
-    # print(">>>> files", files)
     files = [x.replace("_idx.npz", "") for x in files]

     if len(files) > 1:
diff --git a/ppfleetx/data/utils/batch_collate_fn.py b/ppfleetx/data/utils/batch_collate_fn.py
index ad9864956..c75b612d6 100644
--- a/ppfleetx/data/utils/batch_collate_fn.py
+++ b/ppfleetx/data/utils/batch_collate_fn.py
@@ -95,35 +95,54 @@ def gpt_collate_fn(batch):
     return Tuple([Stack() for raw in zip(*batch)])(batch)


-def ernie_collate_data(data, stack_fn=Stack()):
-    num_fields = len(data[0])
-    out = [None] * num_fields
-    # 0. input_ids,
-    # 1. segment_ids,
-    # 2. input_mask,
-    # 3. masked_lm_positions,
-    # 4. masked_lm_labels,
-    # 5. next_sentence_labels
-    for i in (0, 1, 2, 5):
-        out[i] = stack_fn([x[i] for x in data])
-    out[5] = out[5].reshape([-1, 1])
-    batch_size, seq_length = out[0].shape
-    size = num_mask = sum(len(x[3]) for x in data)
-    # masked_lm_positions
-    # Organize as a 1D tensor for gather or use gather_nd
-    if size % 8 != 0:
-        size += 8 - (size % 8)
-    out[3] = np.full(size, 0, dtype=np.int32)
-    # masked_lm_labels
-    out[4] = np.full([size, 1], -1, dtype=np.int64)
-    mask_token_num = 0
-    for i, x in enumerate(data):
-        for j, pos in enumerate(x[3]):
-            out[3][mask_token_num] = i * seq_length + pos
-            out[4][mask_token_num] = x[4][j]
-            mask_token_num += 1
-
-    return out
+class ErnieCollateData():
+    def __init__(self, micro_batch_size=1):
+        self.micro_batch_size = micro_batch_size
+
+    def generate_data(self, data, stack_fn=Stack()):
+        num_fields = len(data[0])
+        out = [None] * num_fields
+        # 0. input_ids,
+        # 1. segment_ids,
+        # 2. input_mask,
+        # 3. masked_lm_positions,
+        # 4. masked_lm_labels,
+        # 5. next_sentence_labels
+        for i in (0, 1, 2, 5):
+            out[i] = stack_fn([x[i] for x in data])
+        out[5] = out[5].reshape([-1, 1])
+        batch_size, seq_length = out[0].shape
+        size = num_mask = sum(len(x[3]) for x in data)
+        # masked_lm_positions
+        # Organize as a 1D tensor for gather or use gather_nd
+        if size % 8 != 0:
+            size += 8 - (size % 8)
+        out[3] = np.full(size, 0, dtype=np.int32)
+
+        # masked_lm_labels
+        out[4] = np.full([size, 1], -1, dtype=np.int64)
+        mask_token_num = 0
+        for i, x in enumerate(data):
+            for j, pos in enumerate(x[3]):
+                out[3][mask_token_num] = i * seq_length + pos
+                out[4][mask_token_num] = x[4][j]
+                mask_token_num += 1
+        return out
+
+    def __call__(self, data):
+        accumulate_steps = len(data) // self.micro_batch_size
+        if accumulate_steps == 1:
+            return self.generate_data(data)
+        else:
+            self.micro_batch_size = len(data) // accumulate_steps
+            all_data = [[] for _ in range(6)]
+            for acc_step in range(accumulate_steps):
+                tmp = self.generate_data(
+                    data[acc_step * self.micro_batch_size:(acc_step + 1) *
+                         self.micro_batch_size])
+                for i in range(6):
+                    all_data[i].append(tmp[i])
+            return all_data


 def imagen_collate_fn(batch):
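With a `micro_batch_size` smaller than the sampled local batch, `ErnieCollateData.__call__` returns, for each of the six fields, a list with one entry per micro-batch instead of a single stacked array, which is what the pipeline engine consumes for gradient accumulation. A hedged usage sketch; it assumes ppfleetx and its paddlenlp dependency are importable, and the toy sample shapes are illustrative only:

```python
import numpy as np
from ppfleetx.data.utils.batch_collate_fn import ErnieCollateData

seq_len = 16

def toy_sample():
    input_ids = np.random.randint(5, 100, size=seq_len).astype("int64")
    segment_ids = np.zeros(seq_len, dtype="int64")
    input_mask = np.ones((1, 1, seq_len), dtype="float32")
    masked_lm_positions = np.array([3, 7], dtype="int64")   # two masked tokens
    masked_lm_labels = np.array([42, 57], dtype="int64")
    next_sentence_label = np.array([1], dtype="int64")
    return (input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_label)

collate = ErnieCollateData(micro_batch_size=2)
out = collate([toy_sample() for _ in range(4)])   # local batch of 4 -> 2 micro-batches
print(len(out), len(out[0]))                      # 6 fields, 2 micro-batches per field
print(out[0][0].shape)                            # (2, 16): input_ids of the first micro-batch
```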
diff --git a/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py b/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py
index 259d10fad..42ea6e2ed 100644
--- a/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py
+++ b/ppfleetx/models/language_model/ernie/dygraph/hybrid_model.py
@@ -106,6 +106,7 @@ def forward(self,
         if input_ids is not None:
             input_shape = paddle.shape(input_ids)
             input_embeddings = self.word_embeddings(input_ids)
+
         else:
             input_shape = paddle.shape(inputs_embeds)[:-1]
             input_embeddings = inputs_embeds
@@ -127,6 +128,7 @@ def forward(self,
         if token_type_ids is None:
             token_type_ids = paddle.zeros(input_shape, dtype="int64")
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
         embeddings = embeddings + token_type_embeddings

         if self.use_task_id:
@@ -732,8 +734,6 @@ def forward(self,
         # masked_lm_loss = self.loss_func(prediction_scores,
         #                                 masked_lm_labels,
         #                                 ignore_index=-1)
-        # print("prediction_scores", prediction_scores.shape, masked_lm_labels.shape)
-
         masked_lm_loss = F.cross_entropy(
             prediction_scores,
             masked_lm_labels,
@@ -757,23 +757,9 @@ def embedding_weight(self):
         return self.word_embeddings.weight

     def forward(self, tensors):
-        print(">> tensors", tensors)
-        input_ids, token_type_ids, attention_mask, masked_positions = tensors
-
-        # if input_ids is not None and inputs_embeds is not None:
-        #     raise ValueError(
-        #         "You cannot specify both input_ids and inputs_embeds at the same time.")
-        # elif input_ids is not None:
-        #     input_shape = paddle.shape(input_ids)
-        # elif inputs_embeds is not None:
-        #     input_shape = paddle.shape(inputs_embeds)[:-1]
-        # else:
-        #     raise ValueError(
-        #         "You have to specify either input_ids or inputs_embeds")
+        input_ids, token_type_ids, attention_mask = tensors

         past_key_values_length = None
-        # if past_key_values is not None:
-        #     past_key_values_length = past_key_values[0][0].shape[2]

         if attention_mask is None:
             attention_mask = paddle.unsqueeze(
@@ -821,28 +807,30 @@ def forward(self, tensors):


 class ErniePoolerPipe(ErniePooler):
-    def forward(self, sequence_output):
+    def forward(self, args):
+        sequence_output = args
         pooled_output = super().forward(sequence_output)
         return sequence_output, pooled_output


-class ErniePretrainingHeadsPipe(ErniePretrainingHeads):
-    def forward(self, args):
-        sequence_output, pooled_output = args
-        prediction_scores, seq_relationship_score = super().forward(
-            sequence_output, pooled_output)
-        return prediction_scores, seq_relationship_score
+class ErniePretrainingCriterionPipe(ErniePretrainingCriterionHybrid):
+    def __init__(self, *heads_args, **heads_kargs):
+        super(ErniePretrainingCriterionPipe, self).__init__()
+        self.heads = ErniePretrainingHeads(*heads_args, **heads_kargs)

-
-class ErniePretrainingCriterionHybridPipe(ErniePretrainingCriterionHybrid):
     def forward(self, outputs, data):
-        masked_lm_labels, next_sentence_labels = data
-        prediction_scores, seq_relationship_score = outputs
+        sequence_output, pooled_output = outputs
+        masked_lm_positions, masked_lm_labels, next_sentence_labels = data
+
+        prediction_scores, seq_relationship_score = self.heads(
+            sequence_output, pooled_output, masked_lm_positions)
+
         lm_loss, sop_loss = super().forward(
             prediction_scores=prediction_scores,
             seq_relationship_score=seq_relationship_score,
             masked_lm_labels=masked_lm_labels,
             next_sentence_labels=next_sentence_labels)
+
         return lm_loss + sop_loss


@@ -878,7 +866,7 @@ def __init__(self,
                     pad_token_id=pad_token_id,
                     weight_attr=None,
                     task_type_vocab_size=task_type_vocab_size,
-                    task_id=task_type_vocab_size,
+                    task_id=task_id,
                     use_task_id=use_task_id))

         for _ in range(num_hidden_layers):
@@ -902,20 +890,18 @@ def __init__(self,
                 LayerNormPipe, normalized_shape=hidden_size))

         self.descs.append(LayerDesc(ErniePoolerPipe, hidden_size=hidden_size))
-        self.descs.append(
-            LayerDesc(
-                ErniePretrainingHeadsPipe,
-                hidden_size=hidden_size,
-                vocab_size=vocab_size,
-                activation=hidden_act,
-                embedding_weights=None,
-                weight_attr=paddle.ParamAttr(
-                    initializer=nn.initializer.TruncatedNormal(
-                        mean=0.0, std=initializer_range))))
+        loss_fun = ErniePretrainingCriterionPipe(
+            hidden_size=hidden_size,
+            vocab_size=vocab_size,
+            activation=hidden_act,
+            embedding_weights=None,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.TruncatedNormal(
+                    mean=0.0, std=initializer_range)))

         super().__init__(
             layers=self.descs,
-            loss_fn=ErniePretrainingCriterionHybridPipe,
+            loss_fn=loss_fun,
             topology=fleet.get_hybrid_communicate_group().topology(),
             seg_method="layer:TransformerEncoderLayer",
             recompute_interval=1 if use_recompute else 0,
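The pipeline refactor above moves the pretraining heads out of the layer list and into the loss object, so `masked_lm_positions` now travels with the labels rather than with the model inputs: on the last stage, the PaddleNLP-style prediction head reshapes `sequence_output` to `[batch * seq_len, hidden]` and gathers only the masked rows using the flattened positions that `ErnieCollateData` builds as `i * seq_length + pos`. A numpy stand-in for that gather, with illustrative shapes:

```python
import numpy as np

batch, seq_len, hidden = 2, 8, 4
sequence_output = np.random.rand(batch, seq_len, hidden).astype("float32")

# flattened positions in the form produced by the collate fn: i * seq_length + pos
masked_positions = np.array([0 * seq_len + 3, 0 * seq_len + 5, 1 * seq_len + 2], dtype="int64")

flat = sequence_output.reshape(batch * seq_len, hidden)
masked_hidden = flat[masked_positions]      # (num_masked, hidden) rows fed to the masked-LM head
print(masked_hidden.shape)                  # (3, 4)
```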
diff --git a/ppfleetx/models/language_model/ernie/ernie_module.py b/ppfleetx/models/language_model/ernie/ernie_module.py
index c516979e7..e8aed197a 100644
--- a/ppfleetx/models/language_model/ernie/ernie_module.py
+++ b/ppfleetx/models/language_model/ernie/ernie_module.py
@@ -56,6 +56,8 @@ def process_data_configs(config):
                 'local_batch_size']
             cfg_data[mode]['dataset'].setdefault('binary_head',
                                                  cfg_global['binary_head'])
+            cfg_data[mode]['loader']['collate_fn'].setdefault(
+                'micro_batch_size', cfg_global['micro_batch_size'])


 def process_model_configs(config):
@@ -105,14 +107,13 @@ def forward(self, tokens):
         return self.model(tokens)

     def pretreating_batch(self, batch):
-
         if self.configs.Distributed.pp_degree > 1:
             input_ids, segment_ids, input_mask, masked_lm_positions, \
                 masked_lm_labels, next_sentence_labels = batch
-            masked_lm_positions = masked_lm_positions.reshape_([1, -1])
-            masked_lm_labels = masked_lm_labels.reshape_([1, -1])
-            data = [(input_ids, segment_ids, input_mask, masked_lm_positions),
-                    (masked_lm_labels, next_sentence_labels)]
+            data = [
+                (input_ids, segment_ids, input_mask),
+                (masked_lm_positions, masked_lm_labels, next_sentence_labels)
+            ]
             return data
         else:
             return batch
diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py
index ad7243032..08704d501 100644
--- a/ppfleetx/utils/config.py
+++ b/ppfleetx/utils/config.py
@@ -55,7 +55,7 @@ def process_dist_config(configs):
         assert nranks == dp_degree * other_degree, \
             "Mismatched config using {} cards with dp_degree[{}]," \
             "mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(nranks, \
-            dp_degree, mp_degree, pp_degree, _sharding_degree)
+            dp_degree, mp_degree, pp_degree, sharding_degree)

     if sharding_config['sharding_degree'] > 1 and reduce_overlap:
         if sharding_config['sharding_stage'] == 3 or sharding_config[
diff --git a/projects/ernie/pretrain_ernie_base.sh b/projects/ernie/pretrain_ernie_base.sh
index 9697aea76..255e2fe13 100644
--- a/projects/ernie/pretrain_ernie_base.sh
+++ b/projects/ernie/pretrain_ernie_base.sh
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-export PYTHONPATH=$PYTHONPATH:/workspace/workspace/PaddleNLP/

 export CUDA_VISIBLE_DEVICES=1
-python tools/train.py -c ppfleetx/configs/nlp/ernie/pretrain_ernie_base_single_card.yaml
+python tools/train.py -c ppfleetx/configs/nlp/ernie/pretrain_ernie_base_345M_single_card.yaml
diff --git a/projects/ernie/pretrain_ernie_base_175B_mp8_pp16.sh b/projects/ernie/pretrain_ernie_base_175B_mp8_pp16.sh
new file mode 100644
index 000000000..a0b1a8945
--- /dev/null
+++ b/projects/ernie/pretrain_ernie_base_175B_mp8_pp16.sh
@@ -0,0 +1,24 @@
+#! /bin/bash
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# export PYTHONPATH=$PYTHONPATH:/workspace/workspace/PaddleNLP/
+
+
+log_dir=log_hybrid
+rm -rf $log_dir
+
+# 175B run_pretrain
+python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
+    ./tools/train.py \
+    -c ./ppfleetx/configs/nlp/ernie/pretrain_ernie_base_175B_mp8_pp16.yaml
diff --git a/projects/ernie/pretrain_ernie_base_3D.sh b/projects/ernie/pretrain_ernie_base_3D.sh
index 76a99a205..34c8357ce 100644
--- a/projects/ernie/pretrain_ernie_base_3D.sh
+++ b/projects/ernie/pretrain_ernie_base_3D.sh
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#export PYTHONPATH=$PYTHONPATH:/workspace/workspace/PaddleNLP/
+# export PYTHONPATH=$PYTHONPATH:/workspace/workspace/PaddleNLP/

 # export CUDA_VISIBLE_DEVICES=1
 # python tools/train.py -c ppfleetx/configs/nlp/ernie/pretrain_ernie_base_single_card.yaml
diff --git a/projects/ernie/pretrain_ernie_base_6.7B_sharding16.sh b/projects/ernie/pretrain_ernie_base_6.7B_sharding16.sh
new file mode 100644
index 000000000..972cbef56
--- /dev/null
+++ b/projects/ernie/pretrain_ernie_base_6.7B_sharding16.sh
@@ -0,0 +1,25 @@
+#! /bin/bash
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# export PYTHONPATH=$PYTHONPATH:/workspace/workspace/PaddleNLP/
+
+
+log_dir=log_hybrid
+rm -rf $log_dir
+
+# 6.7B+sharding16 run_pretrain
+python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
+    ./tools/train.py \
+    -c ./ppfleetx/configs/nlp/ernie/pretrain_ernie_base_6.7B_sharding16.yaml
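For reference, both new launch scripts pass a single node's `--devices` list, so the rank counts implied by the configs determine how many such nodes the `paddle.distributed.launch` job needs, assuming the usual hybrid-parallel rule that world size is the product of the parallel degrees:

```python
def required_ranks(dp=1, mp=1, pp=1, sharding=1):
    return dp * mp * pp * sharding

print(required_ranks(mp=8, pp=16))   # 128 ranks -> 16 nodes of 8 GPUs for the 175B mp8_pp16 config
print(required_ranks(sharding=16))   # 16 ranks  -> 2 nodes of 8 GPUs for the 6.7B sharding16 config
```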