merge develop

RichardWooSJTU committed Oct 12, 2023
2 parents 86283ad + 74eb855 commit c3dcf12
Showing 18 changed files with 155 additions and 80 deletions.
12 changes: 11 additions & 1 deletion llm/finetune_generation.py
@@ -242,6 +242,9 @@ def main():
)

if model_args.prefix_tuning:
if training_args.pipeline_parallel_degree > 1:
raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.")

prefix_tuning_params = get_prefix_tuning_params(model)
prefix_config = PrefixConfig(
num_prefix_tokens=model_args.num_prefix_tokens,
@@ -309,13 +312,20 @@ def compute_metrics_do_generation(eval_preds):
# Create trainer
max_length = data_args.max_length if training_args.pipeline_parallel_degree > 1 else None
padding = "max_length" if training_args.pipeline_parallel_degree > 1 else True
if training_args.pipeline_parallel_degree > 1:
metrics = None
elif data_args.eval_with_do_generation:
metrics = compute_metrics_do_generation
else:
metrics = compute_metrics

trainer = CausalLMTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=dev_ds,
tokenizer=tokenizer,
compute_metrics=compute_metrics_do_generation if data_args.eval_with_do_generation else compute_metrics,
compute_metrics=metrics,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
max_length=max_length,
5 changes: 4 additions & 1 deletion llm/llama/README.md
@@ -22,6 +22,7 @@
| ziqingyang/chinese-alpaca-13b |
| idea-ccnl/ziya-llama-13b-v1 |
| linly-ai/chinese-llama-2-7b |
| linly-ai/chinese-llama-2-13b |
| baichuan-inc/Baichuan-7B |
| baichuan-inc/Baichuan-13B-Base |
| baichuan-inc/Baichuan-13B-Chat |
@@ -51,7 +52,7 @@ Use of the Llama2 model weights must follow the [License](../../paddlenlp/transfor

## 3. Pretraining

For pretraining data preparation, refer to [here](../../model_zoo/ernie-1.0/preprocess/docs/OpenWebText2.md)
For the full data preparation workflow, refer to [here](../../model_zoo/ernie-1.0/preprocess/README.md); for example, OpenWebText2 pretraining data preparation is described [here](../../model_zoo/ernie-1.0/preprocess/docs/OpenWebText2.md)

To make it easy to run and test this model, the project provides a preprocessed training sample of 100k docs:
```shell
@@ -114,6 +115,8 @@ python -u -m paddle.distributed.launch \
3. `continue_training` means continuing training from an existing pretrained model. The 7b model's initial loss is roughly 1.99x; a randomly initialized model starts around 11.x and decreases from there.
4. `use_fused_rms_norm` requires installing the custom ops under [this directory](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/gpt-3/external_ops) with `python setup.py install`. If the op is still not found after installation, set PYTHONPATH additionally.
5. The current script is the sharding version. Users who need 4D parallel training (data, sharding, tensor, and pipeline parallelism) should refer to the `run_trainer_tp4pp2.sh` script.
6. For multi-machine training, if every machine reads the training data from the same location (e.g., a mounted shared disk), specify `--share_folder true` so that only the global rank-0 card builds the cached data. Otherwise, by default, rank 0 on each machine builds its own cached data (see the sketch after this list).
7. If the dataset folder already contains the default cache folder `index-cache/`, an additionally specified `--data_cache` does not take effect; training preferentially loads the contents of the default cache folder.
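
A minimal sketch of the rank selection described in note 6, assuming a `paddle.distributed` launch environment; the helper name and the `PADDLE_RANK_IN_NODE` environment variable are illustrative assumptions, not part of the training script's public interface:

```python
import os

import paddle.distributed as dist


def should_build_cache(share_folder: bool) -> bool:
    """Return True if this process should build the cached index data.

    With --share_folder true, only the global rank-0 process writes the cache,
    since every machine sees the same files. Otherwise, rank 0 on each machine
    builds its own copy of the cache.
    """
    if share_folder:
        return dist.get_rank() == 0
    # Assumed: the launcher exposes the per-machine rank via this variable.
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", "0"))
    return local_rank == 0
```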

## 4. Model Fine-tuning
See the [LLM end-to-end toolchain introduction](../README.md)
1 change: 1 addition & 0 deletions llm/llama/run_pretrain.py
@@ -230,6 +230,7 @@ def create_pretrained_dataset(
seq_length=data_args.max_seq_length,
seed=training_args.seed,
skip_warmup=data_args.skip_warmup,
share_folder=data_args.share_folder,
data_cache_path=data_args.data_cache,
need_data=need_data,
)
1 change: 0 additions & 1 deletion llm/llama/sft_pp_argument.json
@@ -21,7 +21,6 @@
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 2,
53 changes: 16 additions & 37 deletions llm/predictor.py
@@ -385,19 +385,21 @@ def _preprocess(self, source):
self.attention_mask[:] = 0
self.tgt_generation_mask[:] = 0
pre_caches_length = 0 if not self.config.export_precache else self.pre_caches[0].shape[-2]
inputs = dybatch_preprocess(
self.tokenizer,
source,
self.config.src_length,
self.config.max_length,
self.architectures,
top_p=self.config.top_p,
temperature=self.config.temperature,
benchmark=self.config.benchmark,
pre_caches_length=pre_caches_length,
)

if "chatglm" in self.architectures:
inputs = dybatch_preprocess(
self.tokenizer,
source,
self.config.src_length,
self.config.max_length,
self.architectures,
top_p=self.config.top_p,
temperature=self.config.temperature,
benchmark=self.config.benchmark,
pre_caches_length=pre_caches_length,
)
if inputs["input_ids"].shape[0] < self.config.batch_size:
self.tgt_pos = self.tgt_pos[: inputs["input_ids"].shape[0]]
for i in range(inputs["input_ids"].shape[0]):
length = inputs["seq_len_encoder"][i][0]
self.attention_mask[i, 0, :length, :length] = 1
@@ -427,16 +429,6 @@ def _preprocess(self, source):

inputs["tgt_pos"] = self.tgt_pos
elif "bloom" in self.architectures:
inputs = dybatch_preprocess(
self.tokenizer,
source,
self.config.src_length,
self.config.max_length,
self.architectures,
top_p=self.config.top_p,
temperature=self.config.temperature,
benchmark=self.config.benchmark,
)
for i in range(inputs["input_ids"].shape[0]):
length = inputs["seq_len_encoder"][i][0]
self.attention_mask[i, :, :length, :length] = paddle.tril(
@@ -460,7 +452,7 @@ def _preprocess(self, source):
* block_size : (self.model_config.tensor_parallel_rank + 1)
* block_size,
]
alibi = alibi.reshape([inputs["input_ids"].shape[0], block_size, 1, self.config.max_length])
alibi = alibi.reshape([self.config.batch_size, block_size, 1, self.config.max_length])
inputs["position_ids"] = inputs["position_ids"][
self.model_config.tensor_parallel_rank
* block_size : (self.model.config.tensor_parallel_rank + 1)
@@ -469,15 +461,15 @@ def _preprocess(self, source):

alibi_encoder = alibi.expand(
[
inputs["input_ids"].shape[0],
self.config.batch_size,
self.model_config.n_head // self.model_config.tensor_parallel_degree,
self.config.total_max_length,
self.config.total_max_length,
]
)
alibi_decoder = alibi.expand(
[
inputs["input_ids"].shape[0],
self.config.batch_size,
self.model_config.n_head // self.model_config.tensor_parallel_degree,
1,
self.config.total_max_length,
Expand All @@ -491,18 +483,6 @@ def _preprocess(self, source):
)

else:
inputs = dybatch_preprocess(
self.tokenizer,
source,
self.config.src_length,
self.config.max_length,
self.architectures,
top_p=self.config.top_p,
temperature=self.config.temperature,
pre_caches_length=pre_caches_length,
benchmark=self.config.benchmark,
)

for i in range(inputs["input_ids"].shape[0]):
length = inputs["seq_len_encoder"][i][0]
self.attention_mask[i, 0, :length, :length] = paddle.tril(
@@ -618,7 +598,6 @@ def _infer(self, inputs):
for i in range(len(self.cache_kvs_shape)):
input_tensor = self.predictor.get_input_handle("cache_kvs_" + str(i))
input_tensor.share_external_data(self.cache_kvs[i])

input_tensor = self.predictor.get_input_handle("pre_ids")
input_tensor.share_external_data(self.pre_ids)

5 changes: 3 additions & 2 deletions llm/utils.py
@@ -29,6 +29,7 @@
from paddlenlp.datasets import InTokensIterableDataset
from paddlenlp.trainer import Trainer, TrainerCallback
from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
from paddlenlp.transformers import LlamaForCausalLMPipe
from paddlenlp.utils.log import logger


@@ -111,7 +112,7 @@ def get_lora_target_modules(model):
]
elif model.base_model_prefix == "bloom":
target_modules = [".*query_key_value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
elif model.base_model_prefix == "llama":
elif model.base_model_prefix == "llama" or isinstance(model, LlamaForCausalLMPipe):
target_modules = [
".*q_proj.*",
".*v_proj.*",
@@ -183,7 +184,7 @@ def prediction_step(
prediction_loss_only: bool,
ignore_keys=None,
):
if prediction_loss_only:
if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
elif not self.do_generation:
loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
1 change: 0 additions & 1 deletion model_zoo/ernie-1.0/preprocess/README.md
@@ -188,7 +188,6 @@ python -u create_pretraining_data.py \
--data_format "JSON" \
--json_key "text" \
--data_impl "mmap" \
--cn_seg_func "jieba" \
--append_eos \
--log_interval 5 \
--workers 40
10 changes: 8 additions & 2 deletions model_zoo/ernie-1.0/preprocess/docs/WuDaoCorpusBase.md
@@ -49,6 +49,14 @@ python ./trans_to_json.py \
The following describes how the dataset is used for training tasks.

* Taking llama as an example

Note: When using a llama model, there is no need to run word segmentation in advance. Please preprocess the json files in WuDaoCorpus2.0_base_200G into jsonl files in the following format (a hedged conversion sketch follows the example):
```
{"text": "飞桨是功能完备、开源开放的产业级深度学习平台。飞桨拥有..."}
{"text": "PaddleNLP是自然语言..."}
```
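
As a hedged illustration only, a small conversion script along the following lines could produce that jsonl layout; the source path and the `content` field name are assumptions about the raw WuDaoCorpus schema, not something this repository prescribes:

```python
import json

# Hypothetical paths and field names; adjust to the real WuDaoCorpus2.0_base_200G layout.
SRC = "WuDaoCorpus2.0_base_200G/part-000.json"
DST = "wudao.jsonl"

with open(SRC, encoding="utf-8") as fin, open(DST, "w", encoding="utf-8") as fout:
    records = json.load(fin)  # assumed: each source file holds a list of records
    for rec in records:
        text = rec.get("content", "").strip()  # assumed main-body field
        if text:
            fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
```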

Then use the following script to convert the corresponding jsonl files into .bin & .idx files.
```shell
python -u create_pretraining_data.py \
--model_name "idea-ccnl/ziya-llama-13b-v1" \
@@ -58,8 +66,6 @@ python -u create_pretraining_data.py \
--data_format "JSON" \
--json_key "text" \
--data_impl "mmap" \
--cn_seg_func "jieba" \
--cn_splited \
--append_eos \
--log_interval 10000 \
--workers 48
32 changes: 27 additions & 5 deletions paddlenlp/data/blendable_dataset.py
@@ -28,7 +28,7 @@ def print_rank_0(*args, **kwargs):


class BlendableDataset(paddle.io.Dataset):
def __init__(self, datasets, weights, size, *, data_cache_path=None):
def __init__(self, datasets, weights, size, share_folder, *, data_cache_path=None):

self.datasets = datasets
num_datasets = len(datasets)
@@ -82,7 +82,15 @@ def _build_indices():
cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path)
# cache_success = True
# if paddle.distributed.get_rank() == 0 and not cache_hit:
if local_rank == 0 and not cache_hit:
check_rank_flag = not cache_hit and local_rank == 0
if share_folder:
check_rank_flag = not cache_hit and paddle.distributed.get_rank() == 0

print(
f"searching for blendable dataset, cache_hit={cache_hit}, share_folder {share_folder}, check_rank_flag {check_rank_flag}",
flush=True,
)
if check_rank_flag:
print(
" > WARNING: could not find index map files for blendable"
" dataset, building indices on rank 0 ...",
@@ -114,9 +122,19 @@ def _build_indices():
# print_rank_0("Data index creation unsuccessful, exiting.")
# exit()

if paddle.distributed.get_world_size() > 1:
if paddle.in_dynamic_mode():
paddle.distributed.barrier()
else:
while True:
if (not os.path.isfile(index_path)) or (not os.path.isfile(sample_index_path)):
print("building indices on rank 0 ...", flush=True)
time.sleep(3)
else:
try:
np.load(index_path, allow_pickle=True, mmap_mode="r")
print("build success", flush=True)
break
except Exception:
print("%s file is still writing or damaged, please wait for a moment." % index_path)
time.sleep(3)

# paddle.distributed.barrier()
# Load on all ranks.
@@ -128,6 +146,10 @@ def _build_indices():
self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode="r")
assert self.dataset_sample_index.size == self.size
else:
print_rank_0(
"building indices for the blendable dataset, Since --data_cache is not specified, the index file will not be stored.",
flush=True,
)
self.dataset_index, self.dataset_sample_index = _build_indices()

# Check size