Add replace_with_parallel_cross_entropy flag #9579

Merged
5 changes: 5 additions & 0 deletions paddlenlp/trainer/auto_training_args.py
@@ -14,6 +14,7 @@

from dataclasses import dataclass, field

from .trainer_utils import split_parallel_config
from .training_args import TrainingArguments
from .utils import add_start_docstrings

@@ -68,3 +69,7 @@ def __post_init__(self):
        if self.fused_linear:
            fused_passes.enable = True
            fused_passes.fused_passes_list.append("fused_gemm_epilogue_pass")

        mp_configs = split_parallel_config(self.tensor_parallel_config)
        if "replace_with_parallel_cross_entropy" in mp_configs:
            self.strategy.mp_optimization.replace_with_parallel_cross_entropy = True
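
For reference, split_parallel_config is expected to just tokenize the option string, so the new flag can be combined with other tensor-parallel options in a single value. A minimal sketch, assuming the helper in trainer_utils splits on commas or whitespace (the real implementation may differ):

# Hedged sketch of split_parallel_config; the actual helper in
# paddlenlp/trainer/trainer_utils.py may behave differently.
def split_parallel_config(parallel_config: str):
    # Accept either comma-separated or space-separated option lists.
    sep = "," if "," in parallel_config else " "
    return set(p.strip() for p in parallel_config.split(sep) if p.strip())

mp_configs = split_parallel_config("enable_delay_scale_loss replace_with_parallel_cross_entropy")
assert "replace_with_parallel_cross_entropy" in mp_configs
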
3 changes: 3 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -266,6 +266,7 @@ class TrainingArguments:
sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.
sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.
sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.
replace_with_parallel_cross_entropy, replace the 'cross_entropy_with_softmax' OP with the 'c_softmax_with_cross_entropy' OP in the PIR static graph, which can improve model-parallel performance.
pipeline_parallel_config (`str`, *optional*)(
Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it.
following config is support:
@@ -681,6 +682,7 @@ class TrainingArguments:
"sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n"
"sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n"
"sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.\n"
"replace_with_parallel_cross_entropy, it replaces 'cross_entropy_with_softmax' OP with 'c_softmax_with_cross_entropy' OP in PIR static graph, which can improve model parallel performance.\n"
)
},
)
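
A rough usage sketch of how the option would be enabled through the argument parser, mirroring the CI case and the test change below. The parser and dataclass names (PdArgumentParser, AutoTrainingArguments) and the parse_dict call are assumptions based on the file paths in this diff and may differ:

# Hedged sketch: enabling the flag via the trainer arguments. Class and
# method names here are assumptions, not confirmed by this PR.
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.trainer.auto_training_args import AutoTrainingArguments

parser = PdArgumentParser(AutoTrainingArguments)
(training_args,) = parser.parse_dict(
    {
        "output_dir": "./checkpoints/llama2_pretrain_ckpts",
        "enable_auto_parallel": 1,
        "tensor_parallel_degree": 2,
        "tensor_parallel_config": "replace_with_parallel_cross_entropy",
    }
)
# After __post_init__, the auto-parallel strategy should carry the switch:
# training_args.strategy.mp_optimization.replace_with_parallel_cross_entropy -> True

On the command line this corresponds to passing --tensor_parallel_degree 2 --tensor_parallel_config replace_with_parallel_cross_entropy, as the updated CI case below does.
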
@@ -1565,6 +1567,7 @@ def is_segment_parallel_supported():
"enable_delay_scale_loss",
# "enable_mp_skip_c_identity",
# "enable_mp_fused_linear_param_grad_add",
"replace_with_parallel_cross_entropy",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "
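
The whitelist above means a typo in tensor_parallel_config fails fast during argument post-init. A small sketch of the check, with the surrounding parsing code simplified:

# Hedged sketch of the validation step; only the whitelist membership test
# mirrors the hunk above, the surrounding loop is simplified.
supported = {
    "enable_delay_scale_loss",
    "replace_with_parallel_cross_entropy",
}
for x in split_parallel_config("replace_with_parallel_cross_entropy typo_option"):
    if x not in supported:
        raise ValueError(f"Found unknown tensor parallel config {x}")
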
170 changes: 97 additions & 73 deletions scripts/distribute/ci_case_auto.sh
@@ -763,80 +763,104 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() {
auto_task_name="llama_pir_auto_fuse_ffn_attention_qkv_MP2"
auto_case_out_dir="auto_output/$auto_task_name"
auto_case_log_dir="auto_output/$auto_task_name""_log"
rm -rf $auto_case_out_dir
rm -rf $auto_case_log_dir

python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $auto_case_log_dir \
run_pretrain_auto.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $auto_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 5 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 3 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--fuse_attention_ffn false \
--fuse_attention_qkv false \
--use_flash_attention false \
--use_fused_rope true \
--use_fused_rms_norm true \
--max_seq_length 4096 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 2 \
--virtual_pp_degree 1 \
--pipeline_schedule_mode "VPP" \
--sharding "" \
--to_static 1 \
--num_hidden_layers 2 \
>>${log_path}/$FUNCNAME 2>&1

tp_configs=(
"--tensor_parallel_config replace_with_parallel_cross_entropy"
" "
)
for tp_config in "${tp_configs[@]}"; do
rm -rf $auto_case_out_dir
rm -rf $auto_case_log_dir
python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $auto_case_log_dir \
run_pretrain_auto.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $auto_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 10 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 3 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--fuse_attention_ffn false \
--fuse_attention_qkv false \
--use_flash_attention false \
--use_fused_rope true \
--use_fused_rms_norm true \
--max_seq_length 4096 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 2 \
${tp_config} \
--virtual_pp_degree 1 \
--pipeline_schedule_mode "VPP" \
--sharding "" \
--to_static 1 \
--num_hidden_layers 2 \
>>${log_path}/$FUNCNAME 2>&1

auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 5' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_ips=-1
auto_mem=-1
echo "auto result: step 5 loss=$auto_loss ips=$auto_ips mem=$auto_mem"
loss_base=10.21024895
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
loss_base=10.27925682
fi
check_result $FUNCNAME ${loss_base} ${auto_loss} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
auto_loss_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
loss_md5_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_ips_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_mem_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
echo "auto result: step 2 loss=$auto_loss_2 ips=$auto_ips_2 mem=$auto_mem_2"
auto_loss_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
loss_md5_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_ips_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_mem_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
echo "auto result: step 10 loss=$auto_loss_10 ips=$auto_ips_10 mem=$auto_mem_10"
if [[ $tp_config =~ "replace_with_parallel_cross_entropy" ]];then
# This optimization may result in a discrepancy in accuracy.
loss_base_2=10.53477287
loss_base_10=9.4961338
else
loss_base_2=10.53477192
loss_base_10=9.4961338
fi
auto_ips=-1
auto_mem=-1
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
loss_base_2=10.58283806
loss_base_10=10.58283806
fi
check_result $FUNCNAME ${loss_base_2} ${auto_loss_2} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
check_result $FUNCNAME ${loss_base_10} ${auto_loss_10} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
done
export FLAGS_enable_fused_ffn_qkv_pass=0
echo "=========== $FUNCNAME run end ==========="
}
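
The split loss baselines above reflect that 'c_softmax_with_cross_entropy' computes the same quantity as the single-rank op, but over vocabulary shards and with a different reduction order, so the loss may drift in the last few decimals. A minimal numpy sketch of the vocab-parallel computation (two simulated ranks, scalar all-reduces only; not the Paddle kernel itself):

# Hedged sketch: vocab-parallel softmax cross entropy simulated on one process.
# Each "rank" owns a shard of the logits; only scalar max/sum values would be
# all-reduced, so the full logits never need to be gathered under tensor parallelism.
import numpy as np

np.random.seed(0)
vocab, mp_degree = 8, 2
logits = np.random.randn(vocab).astype(np.float32)
label = 5

# Serial reference: standard log-softmax cross entropy over the full vocab.
global_max = logits.max()
ref = -(logits[label] - global_max - np.log(np.exp(logits - global_max).sum()))

# Parallel version over vocab shards.
shards = np.split(logits, mp_degree)                              # shards[i] lives on rank i
shared_max = max(s.max() for s in shards)                         # allreduce(max)
shared_sum = sum(np.exp(s - shared_max).sum() for s in shards)    # allreduce(sum)
owner, offset = divmod(label, vocab // mp_degree)                 # rank holding the label
par = -(shards[owner][offset] - shared_max - np.log(shared_sum))

assert np.allclose(ref, par, atol=1e-6)
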
1 change: 1 addition & 0 deletions tests/trainer/test_auto_argparser.py
@@ -66,6 +66,7 @@ class AutoArgparserTest(unittest.TestCase):
"num_cycles": 0.5,
"num_train_epochs": 3.0,
"output_dir": "./checkpoints/llama2_pretrain_ckpts",
"tensor_parallel_config": "replace_with_parallel_cross_entropy",
}

def test_parse_cmd_lines(self):