diff --git a/paddlenlp/trainer/auto_training_args.py b/paddlenlp/trainer/auto_training_args.py
index ee0a5c6c503e..eaa394b1c4a2 100644
--- a/paddlenlp/trainer/auto_training_args.py
+++ b/paddlenlp/trainer/auto_training_args.py
@@ -14,6 +14,7 @@
 
 from dataclasses import dataclass, field
 
+from .trainer_utils import split_parallel_config
 from .training_args import TrainingArguments
 from .utils import add_start_docstrings
 
@@ -68,3 +69,7 @@ def __post_init__(self):
         if self.fused_linear:
             fused_passes.enable = True
             fused_passes.fused_passes_list.append("fused_gemm_epilogue_pass")
+
+        mp_configs = split_parallel_config(self.tensor_parallel_config)
+        if "replace_with_parallel_cross_entropy" in mp_configs:
+            self.strategy.mp_optimization.replace_with_parallel_cross_entropy = True
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 5d1dad82a831..6c1d5e6e7333 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -266,6 +266,7 @@ class TrainingArguments:
             sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.
             sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.
             sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.
+            replace_with_parallel_cross_entropy, replace the 'cross_entropy_with_softmax' OP with the 'c_softmax_with_cross_entropy' OP in the PIR static graph, which can improve model parallel performance.
         pipeline_parallel_config (`str`, *optional*)(
             Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it.
             following config is support:
@@ -681,6 +682,7 @@ class TrainingArguments:
                 "sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n"
                 "sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n"
                 "sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.\n"
+                "replace_with_parallel_cross_entropy, replace the 'cross_entropy_with_softmax' OP with the 'c_softmax_with_cross_entropy' OP in the PIR static graph, which can improve model parallel performance.\n"
             )
         },
     )
@@ -1565,6 +1567,7 @@ def is_segment_parallel_supported():
                         "enable_delay_scale_loss",
                         # "enable_mp_skip_c_identity",
                         # "enable_mp_fused_linear_param_grad_add",
+                        "replace_with_parallel_cross_entropy",
                     ]:
                         raise ValueError(
                             f"Found unknown tensor parallell config {x}, "
diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index 2dc44fb57cec..b5162685c5e1 100755
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -763,80 +763,104 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() {
     auto_task_name="llama_pir_auto_fuse_ffn_attention_qkv_MP2"
     auto_case_out_dir="auto_output/$auto_task_name"
     auto_case_log_dir="auto_output/$auto_task_name""_log"
-    rm -rf $auto_case_out_dir
-    rm -rf $auto_case_log_dir
-
-    python -u -m paddle.distributed.launch \
-        --gpus "0,1" \
-        --log_dir $auto_case_log_dir \
-        run_pretrain_auto.py \
-        --model_name_or_path "facebook/llama-7b" \
-        --tokenizer_name_or_path "facebook/llama-7b" \
-        --input_dir "./data" \
-        --output_dir $auto_case_out_dir \
-        --split 949,50,1 \
-        --weight_decay 0.01 \
-        --warmup_ratio 0.01 \
-        --warmup_steps 30 \
-        --max_grad_norm 0.0 \
-        --learning_rate 3e-05 \
-        --min_learning_rate 3e-06 \
-        --max_steps 5 \
-        --logging_steps 1 \
-        --eval_steps 1000 \
-        --save_steps 3 \
-        --continue_training 0 \
-        --do_train true \
-        --do_eval false \
-        --do_predict false \
-        --disable_tqdm true \
-        --skip_profile_timer true \
-        --save_total_limit 2 \
-        --device gpu \
-        --disable_tqdm true \
-        --dataloader_num_workers 1 \
-        --distributed_dataloader 0 \
-        --enable_auto_parallel 1 \
-        --per_device_train_batch_size 1 \
-        --gradient_accumulation_steps 1 \
-        --per_device_eval_batch_size 2 \
-        --recompute false \
-        --recompute_use_reentrant true \
-        --recompute_granularity full \
-        --pp_recompute_interval 0 \
-        --bf16 0 \
-        --fp16_opt_level "O2" \
-        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
-        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
-        --amp_master_grad false \
-        --fuse_attention_ffn false \
-        --fuse_attention_qkv false \
-        --use_flash_attention false \
-        --use_fused_rope true \
-        --use_fused_rms_norm true \
-        --max_seq_length 4096 \
-        --sequence_parallel false \
-        --pipeline_parallel_degree 1 \
-        --sharding_parallel_degree 1 \
-        --tensor_parallel_degree 2 \
-        --virtual_pp_degree 1 \
-        --pipeline_schedule_mode "VPP" \
-        --sharding "" \
-        --to_static 1 \
-        --num_hidden_layers 2 \
-        >>${log_path}/$FUNCNAME 2>&1
+
+    tp_configs=(
+        "--tensor_parallel_config replace_with_parallel_cross_entropy"
+        " "
+    )
+    for tp_config in "${tp_configs[@]}"; do
+        rm -rf $auto_case_out_dir
+        rm -rf $auto_case_log_dir
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1" \
+            --log_dir $auto_case_log_dir \
+            run_pretrain_auto.py \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $auto_case_out_dir \
+            --split 949,50,1 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --warmup_steps 30 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps 10 \
+            --logging_steps 1 \
+            --eval_steps 1000 \
+            --save_steps 3 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --save_total_limit 2 \
+            --device gpu \
+            --disable_tqdm true \
+            --dataloader_num_workers 1 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 1 \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --bf16 0 \
+            --fp16_opt_level "O2" \
+            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+            --amp_master_grad false \
+            --fuse_attention_ffn false \
+            --fuse_attention_qkv false \
+            --use_flash_attention false \
+            --use_fused_rope true \
+            --use_fused_rms_norm true \
+            --max_seq_length 4096 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 1 \
+            --sharding_parallel_degree 1 \
+            --tensor_parallel_degree 2 \
+            ${tp_config} \
+            --virtual_pp_degree 1 \
+            --pipeline_schedule_mode "VPP" \
+            --sharding "" \
+            --to_static 1 \
+            --num_hidden_layers 2 \
+            >>${log_path}/$FUNCNAME 2>&1
 
-    auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 5' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-    auto_ips=-1
-    auto_mem=-1
-    echo "auto result: step 5 loss=$auto_loss ips=$auto_ips mem=$auto_mem"
-    loss_base=10.21024895
-    ips_base=-1
-    mem_base=-1
-    if [ $IS_A100 -ne 0 ];then
-        loss_base=10.27925682
-    fi
-    check_result $FUNCNAME ${loss_base} ${auto_loss} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
+        auto_loss_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        loss_md5_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+        auto_ips_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+        auto_mem_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+        echo "auto result: step 2 loss=$auto_loss_2 ips=$auto_ips_2 mem=$auto_mem_2"
+        auto_loss_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        loss_md5_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+        auto_ips_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+        auto_mem_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+        echo "auto result: step 10 loss=$auto_loss_10 ips=$auto_ips_10 mem=$auto_mem_10"
+        if [[ $tp_config =~ "replace_with_parallel_cross_entropy" ]];then
+            # This optimization may result in a discrepancy in accuracy.
+            loss_base_2=10.53477287
+            loss_base_10=9.4961338
+        else
+            loss_base_2=10.53477192
+            loss_base_10=9.4961338
+        fi
+        auto_ips=-1
+        auto_mem=-1
+        ips_base=-1
+        mem_base=-1
+        if [ $IS_A100 -ne 0 ];then
+            loss_base_2=10.58283806
+            loss_base_10=10.58283806
+        fi
+        check_result $FUNCNAME ${loss_base_2} ${auto_loss_2} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
+        check_result $FUNCNAME ${loss_base_10} ${auto_loss_10} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
+    done
     export FLAGS_enable_fused_ffn_qkv_pass=0
     echo "=========== $FUNCNAME run end ==========="
 }
diff --git a/tests/trainer/test_auto_argparser.py b/tests/trainer/test_auto_argparser.py
index 9d5b311c41e9..d10d8a786d70 100644
--- a/tests/trainer/test_auto_argparser.py
+++ b/tests/trainer/test_auto_argparser.py
@@ -66,6 +66,7 @@ class AutoArgparserTest(unittest.TestCase):
         "num_cycles": 0.5,
         "num_train_epochs": 3.0,
         "output_dir": "./checkpoints/llama2_pretrain_ckpts",
+        "tensor_parallel_config": "replace_with_parallel_cross_entropy",
     }
 
     def test_parse_cmd_lines(self):
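
Usage note (editor's sketch, not part of the patch): a minimal example of how the new `replace_with_parallel_cross_entropy` option is expected to be consumed, assuming `AutoTrainingArguments` is the dataclass defined in `paddlenlp/trainer/auto_training_args.py` and that its `__post_init__` builds the auto-parallel `strategy`; the argument values are illustrative only.

    # Hypothetical sketch: the class name AutoTrainingArguments and the
    # availability of args.strategy after __post_init__ are assumptions
    # inferred from the code added in this patch, not verified API guarantees.
    from paddlenlp.trainer.auto_training_args import AutoTrainingArguments

    args = AutoTrainingArguments(
        output_dir="./checkpoints/llama2_pretrain_ckpts",
        enable_auto_parallel=True,
        tensor_parallel_degree=2,
        tensor_parallel_config="replace_with_parallel_cross_entropy",
    )

    # The new __post_init__ branch parses tensor_parallel_config with
    # split_parallel_config and flips the mp_optimization flag.
    assert args.strategy.mp_optimization.replace_with_parallel_cross_entropy

The equivalent launcher form, exercised by the updated CI case, is passing `--tensor_parallel_config replace_with_parallel_cross_entropy` to `run_pretrain_auto.py`.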