From 0ddfb042740b18736a12918a81ebe5145d6ffd37 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Mon, 4 Nov 2024 20:29:29 +0800
Subject: [PATCH] [Cherry-pick] fix unified_checkpoint to use newest model (#9362)

* [Unified Checkpoint] Fix fp32 dtype for using newest paddle(#9360)
---
 .../trainer/unified_checkpoint/check_completion.py   |  7 +------
 paddlenlp/trainer/unified_checkpoint/load_dynamic.py |  7 +------
 paddlenlp/trainer/unified_checkpoint/load_local.py   |  8 ++------
 .../unified_checkpoint/load_save_single_card.py      |  9 ++-------
 .../trainer/unified_checkpoint/unified_checkpoint.py |  9 ++-------
 paddlenlp/trainer/unified_checkpoint/utils.py        | 11 +----------
 6 files changed, 9 insertions(+), 42 deletions(-)

diff --git a/paddlenlp/trainer/unified_checkpoint/check_completion.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py
index cf337c468463..8165a4542820 100644
--- a/paddlenlp/trainer/unified_checkpoint/check_completion.py
+++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py
@@ -30,11 +30,6 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import flatten_list
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from .utils import (
     get_expected_state_dict,
     is_sharding_split_param_mode,
@@ -200,7 +195,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
             if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
                 continue
 
-            if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
+            if is_master_weights and state_dict[key].dtype == paddle.float32:
                 continue
 
             if not is_master_weights:
diff --git a/paddlenlp/trainer/unified_checkpoint/load_dynamic.py b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py
index 064ecacc7c3c..7f34ddc145c0 100644
--- a/paddlenlp/trainer/unified_checkpoint/load_dynamic.py
+++ b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py
@@ -22,11 +22,6 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import _load_state_dict_into_model
 from paddlenlp.transformers.utils import device_guard, is_safetensors_available
@@ -474,7 +469,7 @@ def check_optimizer_param(parameter):
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py
index 552289d8f383..5d16fd4ef966 100644
--- a/paddlenlp/trainer/unified_checkpoint/load_local.py
+++ b/paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -16,13 +16,9 @@
 import gc
 import os
 
+import paddle
 from tqdm.auto import tqdm
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import (
     _load_state_dict_into_model,
@@ -252,7 +248,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
diff --git a/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py
index 581dc9b0da53..d481cef37749 100644
--- a/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py
+++ b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py
@@ -19,11 +19,6 @@
 
 import paddle
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.transformers.model_utils import (
     _load_state_dict_into_model,
@@ -120,7 +115,7 @@ def save_single_card_optimizer(model, optimizer, output_dir):
     fp32_weight = {}
     for k, v in state_dict.items():
         static2struct_name_mappings[v.name] = k
-        if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+        if master_weights is not None and v.dtype == paddle.float32:
             fp32_weight[k] = v
 
     # rename optimizer param
@@ -226,7 +221,7 @@ def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: str):
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
index 5628874d5c30..0190529a84e3 100644
--- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
+++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -19,11 +19,6 @@
 import paddle
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.argparser import strtobool
 from paddlenlp.trainer.utils.helper import distributed_isfile
@@ -281,7 +276,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
             key_name = key.split("/")
             static_name = struct2static_name_mappings[key_name[0]]
             if has_master_weights:
-                if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                if model_state_dict[key_name[0]].dtype != paddle.float32:
                     key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
                 else:
                     key_name = "_".join([static_name, key_name[1]])
@@ -529,7 +524,7 @@ def unified_optimizer_into_shards(
     fp32_weight = {}
     for k, v in state_dict.items():
         static2struct_name_mappings[v.name] = k
-        if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+        if master_weights is not None and v.dtype == paddle.float32:
             if args.dataset_rank > 0:  # deal with different dataset rank.
                 continue
             fp32_weight[k] = v
diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py
index bad8dabbafa2..74db0e20e184 100644
--- a/paddlenlp/trainer/unified_checkpoint/utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/utils.py
@@ -21,11 +21,6 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
-try:
-    from paddle.base import core
-except:
-    core = None
-
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
 from paddlenlp.trainer.utils.helper import distributed_isfile
@@ -231,11 +226,7 @@ def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weight
     expected_keys = []
     for key in list(sharded_metadata["all_optimizer_keys"]):
         key_name = key.split("/")[0]
-        if (
-            is_master_weights
-            and key_name in model_state_dict
-            and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32
-        ):
+        if is_master_weights and key_name in model_state_dict and model_state_dict[key_name].dtype == paddle.float32:
            continue
 
        if args.use_expert_parallel and args.data_parallel_rank > 0:
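
The change is the same in every hunk: dtype checks that previously went through the legacy core.VarDesc.VarType.FP32 enum, behind a guarded "from paddle.base import core" import, now compare against paddle.float32 directly, so the core fallback can be removed. Below is a minimal sketch of that pattern, assuming a recent Paddle release where Tensor.dtype compares equal to paddle.float32; the helper and tensors are illustrative and not taken from the patch.

import paddle

# Illustrative sketch, not part of the patch. On newer Paddle, Tensor.dtype
# compares directly against paddle.float32, which is why the guarded
# `from paddle.base import core` import and core.VarDesc.VarType.FP32 are dropped.
def needs_master_weight(param):
    # fp16/bf16 parameters keep a separate FP32 master copy; fp32 ones do not.
    return param.dtype != paddle.float32

fp32_param = paddle.zeros([2, 2], dtype="float32")
fp16_param = paddle.zeros([2, 2], dtype="float16")
print(needs_master_weight(fp32_param))  # False
print(needs_master_weight(fp16_param))  # True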