Refine checkpoint converter #9001

Merged
merged 5 commits on Aug 27, 2024
4 changes: 3 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
@@ -722,6 +722,8 @@
)

if self.args.to_static:
if self.model_wrapped._mode is None:
self.model_wrapped.train()

model_state_dict = {
key: value
for key, value in self.model_wrapped.state_dict("param").items()
@@ -757,7 +759,7 @@

if self.args.auto_parallel_resume_form_hybrid_parallel:
CheckpointConverter(
resume_from_checkpoint, state_dict, parameter_to_structured_name
resume_from_checkpoint, state_dict, parameter_to_structured_name, self.args
).load_from_hybrid_parallel_checkpoint()
else:
ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
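For context, a minimal sketch of what the two call-site changes in `auto_trainer.py` amount to: the wrapped model gets an explicit mode before the dy2st export, and the trainer args are now forwarded to `CheckpointConverter`. The `trainer` object below is a simplified stand-in, not the real `AutoTrainer`, and the import path is assumed from the file location in this PR.

```python
# Minimal sketch, not the actual AutoTrainer code: `trainer` is a hypothetical
# stand-in exposing `args`, `model_wrapped`, and the resume inputs.
from paddlenlp.trainer.utils.ckpt_converter import CheckpointConverter  # path assumed from this PR


def resume_with_converter(trainer, resume_from_checkpoint, state_dict, parameter_to_structured_name):
    if trainer.args.to_static:
        # Dy2St export needs a concrete mode; fall back to train mode if the
        # wrapped model has not been switched to train/eval yet.
        if trainer.model_wrapped._mode is None:
            trainer.model_wrapped.train()

    if trainer.args.auto_parallel_resume_form_hybrid_parallel:
        # The training args are now passed through so the converter can honor
        # flags such as `ignore_load_lr_and_optim`.
        CheckpointConverter(
            resume_from_checkpoint, state_dict, parameter_to_structured_name, trainer.args
        ).load_from_hybrid_parallel_checkpoint()
```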
37 changes: 24 additions & 13 deletions paddlenlp/trainer/utils/ckpt_converter.py
@@ -33,15 +33,22 @@
MODEL_WEIGHT_SUFFIX = ".pdparams"
OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
SCHEDULER_NAME = "scheduler.pdparams"
SCALAR_NAME = "scalar.pdparams"

MODEL_META_FILE_NAME = "model_meta.json"
OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]
MODEL_STATE_FILE_MIN_SIZE = 512


class CheckpointConverter:
def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None):
def __init__(
self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
):
self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
self.path = hybrid_parallel_ckpt_path

if trainging_args.ignore_load_lr_and_optim:
state_dict.pop("optimizer")

self.auto_parallel_state_dict = self.flatten_state_dict(state_dict)
self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name)
model_state_global_shape = {}
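A hedged sketch of the new constructor behaviour: the extra `trainging_args` parameter (name as in the diff) lets the converter drop the optimizer entry from the target state dict when `ignore_load_lr_and_optim` is set, so only model weights are reconstructed. The `training_args` object and the None guard below are simplifications, not part of the diff.

```python
# Simplified sketch of the optimizer-skipping branch; `training_args` is a
# hypothetical object carrying only the flag used here.
def drop_optimizer_if_ignored(state_dict, training_args=None):
    if training_args is not None and training_args.ignore_load_lr_and_optim:
        # Moments, beta accumulators and master weights are skipped entirely;
        # only the model weights remain to be flattened and loaded.
        state_dict.pop("optimizer", None)
    return state_dict
```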
@@ -74,9 +81,9 @@
for k, v in self.auto_parallel_state_dict.items():
if k in self.patch_dict:
del_keys.append(k)

for k in del_keys:
self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
for k in del_keys:
self.auto_parallel_state_dict.pop(k)

flags = [
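The reordered loops above can be read as a two-pass rename. The sketch below is a simplified version with a plain dict standing in for the converter state: it copies every value under its patched name first and removes the old keys only afterwards.

```python
# Two-pass rename sketch, assuming `patch_dict` maps old structured names to
# new ones; this mirrors the shape of the loops above, not the exact code.
def apply_patch_dict(state_dict, patch_dict):
    del_keys = [k for k in state_dict if k in patch_dict]
    for k in del_keys:
        state_dict[patch_dict[k]] = state_dict[k]  # first pass: copy under the new name
    for k in del_keys:
        state_dict.pop(k)                          # second pass: drop the stale keys
    return state_dict
```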
@@ -896,25 +903,26 @@
return renamed_state_dict

def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):

name_mapping = {}
suffix_bucket = {}
assert len(optimizer_state_dict) % len(model_state_keys) == 0
for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
suffix_bucket[suffix] = []
for satte_name, satte_value in optimizer_state_dict.items():
if "moment1" in satte_name:
suffix_bucket[".moment1"].append(satte_name)
elif "moment2" in satte_name:
suffix_bucket[".moment2"].append(satte_name)
elif "beta1_pow_acc" in satte_name:
suffix_bucket[".beta1_pow_acc"].append(satte_name)
elif "beta2_pow_acc" in satte_name:
suffix_bucket[".beta2_pow_acc"].append(satte_name)
for opt_name, opt_value in optimizer_state_dict.items():
if "moment1" in opt_name:
suffix_bucket[".moment1"].append(opt_name)
elif "moment2" in opt_name:
suffix_bucket[".moment2"].append(opt_name)
elif "beta1_pow_acc" in opt_name:
suffix_bucket[".beta1_pow_acc"].append(opt_name)
elif "beta2_pow_acc" in opt_name:
suffix_bucket[".beta2_pow_acc"].append(opt_name)

Check warning on line 919 in paddlenlp/trainer/utils/ckpt_converter.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/ckpt_converter.py#L911-L919

Added lines #L911 - L919 were not covered by tests
else:
suffix_bucket[".master_weight"].append(satte_name)
suffix_bucket[".master_weight"].append(opt_name)

for suffix, old_names in suffix_bucket.items():
if len(old_names) == 0:
continue
assert len(old_names) == len(model_state_keys)
for i in range(len(old_names)):
name_mapping[old_names[i]] = model_state_keys[i] + suffix
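The renaming helper is essentially a positional match between optimizer state names and model parameter names. A rough sketch with hypothetical toy names (only the bucket keys mirror `OPTIMIZER_STATE_NAME_SUFFIX` from the file above):

```python
# Positional-renaming sketch: bucket optimizer states by suffix, then pair
# each bucket with the model parameter names in order.
OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]


def build_name_mapping(model_state_keys, optimizer_state_names):
    suffix_bucket = {suffix: [] for suffix in OPTIMIZER_STATE_NAME_SUFFIX}
    for opt_name in optimizer_state_names:
        # Anything that is not a moment/beta accumulator falls back to
        # ".master_weight", matching the `else` branch above.
        matched = next(
            (s for s in OPTIMIZER_STATE_NAME_SUFFIX[:-1] if s.strip(".") in opt_name),
            ".master_weight",
        )
        suffix_bucket[matched].append(opt_name)

    name_mapping = {}
    for suffix, old_names in suffix_bucket.items():
        if not old_names:
            continue
        # The optimizer states are assumed to follow the order of the model
        # parameters, so pairing them index by index yields structured names.
        assert len(old_names) == len(model_state_keys)
        for old, new in zip(old_names, model_state_keys):
            name_mapping[old] = new + suffix
    return name_mapping


# Example with hypothetical names:
# build_name_mapping(["llama.embed_tokens.weight"], ["embedding_0.w_0.moment1_0"])
# -> {"embedding_0.w_0.moment1_0": "llama.embed_tokens.weight.moment1"}
```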
@@ -1011,6 +1019,9 @@
cur_rank_optimizer_state_file_names.append(file_name)
if SCHEDULER_NAME in cur_rank_model_state_file_names:
cur_rank_model_state_file_names.remove(SCHEDULER_NAME)
if SCALAR_NAME in cur_rank_model_state_file_names:
cur_rank_model_state_file_names.remove(SCALAR_NAME)

return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names

def get_distribution_rank_from_file_name(self, file_name):
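For completeness, a tiny sketch of the file filtering above: the scheduler state and the newly handled scalar state share the `.pdparams` suffix with model weights, so both names have to be dropped before the remaining files are treated as per-rank model state. The helper name below is hypothetical.

```python
# Sketch of the non-parameter-file filtering; names mirror the constants above.
SCHEDULER_NAME = "scheduler.pdparams"
SCALAR_NAME = "scalar.pdparams"


def filter_model_state_files(file_names):
    return [name for name in file_names if name not in (SCHEDULER_NAME, SCALAR_NAME)]
```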
160 changes: 160 additions & 0 deletions scripts/distribute/ci_case_auto.sh
@@ -54,6 +54,7 @@ function llama_case_list_auto() {
llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2

llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
}

function llm_gpt_case_list_auto() {
@@ -1062,6 +1063,165 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
echo "=========== $FUNCNAME run end ==========="
}

function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_enable_pir_api=1
export FLAGS_max_inplace_grad_add=3

echo "---- run hybrid and save ckpt ----"
dy_task_name="llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1"
dy_case_out_dir="dy_output/$dy_task_name"
dy_case_log_dir="dy_output/$dy_task_name""_log"
rm -rf $dy_case_out_dir
rm -rf $dy_case_log_dir

python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $dy_case_log_dir \
../../run_pretrain.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $dy_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 5 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 3 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--enable_linear_fused_grad_add false \
--fuse_attention_ffn true \
--fuse_attention_qkv false \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention 0 \
--use_fused_rope false \
--use_fused_rms_norm 0 \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 1 \
--sharding "" \
--to_static 0 \
--num_hidden_layers 2 \
>>${log_path}/$FUNCNAME 2>&1
dy_loss=`cat $dy_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
dy_ips=-1
dy_mem=-1
echo "hybrid result: loss=$dy_loss ips=$dy_ips mem=$dy_mem"

echo "---- run auto parallel resueme from hybrid ckpt ----"
auto_task_name="llama_auto_parallel_bs2_fp32_DP2-MP1-PP1"
auto_case_out_dir="auto_output/$auto_task_name"
auto_case_log_dir="auto_output/$auto_task_name""_log"
rm -rf $auto_case_out_dir
rm -rf $auto_case_log_dir

python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $auto_case_log_dir \
run_pretrain_auto.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $auto_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 4 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 1000 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--fuse_attention_ffn true \
--fuse_attention_qkv false \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention 0 \
--use_fused_rope false \
--use_fused_rms_norm 0 \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 1 \
--pipeline_schedule_mode "VPP" \
--sharding "" \
--to_static 1 \
--num_hidden_layers 2 \
--resume_from_checkpoint "dy_output/llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1/checkpoint-3" \
--auto_parallel_resume_form_hybrid_parallel 1 \
>>${log_path}/$FUNCNAME 2>&1
auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_ips=-1
auto_mem=-1
echo "auto result: loss=$auto_loss ips=$auto_ips mem=$auto_mem"

check_result $FUNCNAME ${dy_loss} ${auto_loss} ${dy_ips} ${auto_ips} ${dy_mem} ${auto_mem}
echo "=========== $FUNCNAME run end ==========="
}

function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH